Compare commits
163 Commits
version/20 ... version/20
| SHA1 |
|---|
| 09b02e1aec |
| 451a9aaf01 |
| eaee7cb562 |
| a010c91a52 |
| 709194330f |
| 5914bbf173 |
| 5e9166f859 |
| 35b8ef6592 |
| 772a939f17 |
| 24971801cf |
| 43aebe8cb2 |
| 19cfc87c84 |
| f920f183c8 |
| 97f979c81e |
| e61411d396 |
| c4f985f542 |
| 302dee7ab2 |
| 83c12ad483 |
| 4224fd5c6f |
| 597ce1eb42 |
| 5ef385f0bb |
| cda4be3d47 |
| 8cdf22fc94 |
| 6efc7578ef |
| 4e2457560d |
| 2ddf122d27 |
| a24651437a |
| 30bb7acb17 |
| 7859145138 |
| 8a8aafec81 |
| deebdf2bcc |
| 4982c4abcb |
| 1486f90077 |
| f4988bc45e |
| 8abc9cc031 |
| 534689895c |
| 8a0dd6be24 |
| 65d2eed82d |
| e450e7b107 |
| 552ddda909 |
| bafeff7306 |
| 6791436302 |
| 7eda794070 |
| e3129c1067 |
| ff481ba6e7 |
| a106bad2db |
| 3a1c311d02 |
| 6465333f4f |
| b761659227 |
| 9321c355f8 |
| 86c8e79ea1 |
| 8916b1f8ab |
| 41fcf2aba6 |
| 87e72b08a9 |
| b2fcd42e3c |
| fc1b47a80f |
| af14e3502e |
| a2faa5ceb5 |
| 63a19a1381 |
| b472dcb7e7 |
| 6303909031 |
| 4bdc06865b |
| 2ee48cd039 |
| 893d5f452b |
| 340a9bc8ee |
| cb3d9f83f1 |
| 4ba55aa8e9 |
| bab6f501ec |
| 7327939684 |
| ffb0135f06 |
| ee0ddc3d17 |
| 5dd979d66c |
| a9bd34f3c5 |
| db316b59c5 |
| 6209714f87 |
| 1ed2bddba7 |
| 26b35c9b7b |
| 86a9271f75 |
| 402ed9bd20 |
| 68a0684569 |
| bd2e453218 |
| 1f31c63e57 |
| 480410efa2 |
| e9bfee52ed |
| 326b574d54 |
| 0a7abcf2ad |
| 9e5019881e |
| 8071750681 |
| f2f0931904 |
| a91204e5b9 |
| b14c22cbff |
| b3e40c6aed |
| 873aa4bb22 |
| c1ea78c422 |
| 3c8bbc2621 |
| 42a9979d91 |
| b7f94df4d9 |
| 4143d3fe28 |
| f95c06b76f |
| e3e9178ccc |
| b694816e7b |
| e046000f36 |
| edb5caae9b |
| 02d27651f3 |
| 44cd4d847d |
| 472256794d |
| cbb6887983 |
| 317e9ec605 |
| ada2a16412 |
| 61f6b0f122 |
| 6a3f7e45cf |
| 2b78c4ba86 |
| 680ef641fb |
| 2b5504ff63 |
| f8a6aa3250 |
| 6c23fc4b2b |
| 639c2f5c2e |
| e44632f9a0 |
| 3f2ce34468 |
| 426cef998f |
| 8ddb62ed0f |
| 572f6d4ea0 |
| 8db68410c6 |
| caa3c3de32 |
| 23b5ca761a |
| f1b9021e3e |
| 99c62af89e |
| 8ae50814fe |
| 2e2b491ec7 |
| ac432e78e2 |
| 83ac42ac43 |
| 4bd1cd127b |
| 2eb5a5cc76 |
| 75051687e6 |
| 7e316b5fc2 |
| 5594ad0b36 |
| ea097afeae |
| b77b4b5c80 |
| f8dc7f48f2 |
| 692e75b057 |
| 02771683a6 |
| 40404ff41d |
| fdd5211253 |
| 85a417d22e |
| 66c530ea06 |
| 347c3793fc |
| cf78c89830 |
| 20c738c384 |
| 4f54ce6afb |
| f0d7edb963 |
| e42ad8db93 |
| e917e756cc |
| b4963bec76 |
| 0d23796989 |
| d0ceafe79e |
| f2023a7af2 |
| 31d597005f |
| 62dc86be7b |
| 7aa8e35f87 |
| 60b95271eb |
| 382b0e8941 |
| 3b068610b9 |
| 9a8f62f42e |
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 2021.12.1-rc1
+current_version = 2021.12.1-rc4
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)\-?(?P<release>.*)
15 .github/workflows/ci-main.yml vendored
@@ -28,6 +28,7 @@ jobs:
        - isort
        - bandit
        - pyright
+       - pending-migrations
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
@@ -88,9 +89,11 @@ jobs:
        run: |
          # Copy current, latest config to local
          cp authentik/lib/default.yml local.env.yml
+         cp -R .github ..
+         cp -R scripts ..
          git checkout $(git describe --abbrev=0 --match 'version/*')
-         git checkout $GITHUB_HEAD_REF -- .github
-         git checkout $GITHUB_HEAD_REF -- scripts
+         rm -rf .github/ scripts/
+         mv ../.github ../scripts .
      - name: prepare
        env:
          INSTALL: ${{ steps.cache-pipenv.outputs.cache-hit }}
@@ -104,6 +107,7 @@ jobs:
        run: |
          set -x
          git fetch
          git reset --hard HEAD
+         git checkout $GITHUB_HEAD_REF
          pipenv sync --dev
      - name: prepare
@@ -219,7 +223,7 @@ jobs:
          testspace [e2e]unittest.xml --link=codecov
      - if: ${{ always() }}
        uses: codecov/codecov-action@v2
- build:
+ ci-core-mark:
    needs:
      - lint
      - test-migrations
@@ -228,6 +232,11 @@ jobs:
      - test-integration
      - test-e2e
    runs-on: ubuntu-latest
+   steps:
+     - run: echo mark
+ build:
+   needs: ci-core-mark
+   runs-on: ubuntu-latest
    timeout-minutes: 120
    strategy:
      fail-fast: false
8 .github/workflows/ci-outpost.yml vendored
@@ -30,10 +30,16 @@ jobs:
          -w /app \
          golangci/golangci-lint:v1.39.0 \
          golangci-lint run -v --timeout 200s
+ ci-outpost-mark:
+   needs:
+     - lint-golint
+   runs-on: ubuntu-latest
+   steps:
+     - run: echo mark
  build:
    timeout-minutes: 120
    needs:
-     - lint-golint
+     - ci-outpost-mark
    strategy:
      fail-fast: false
      matrix:
8 .github/workflows/ci-web.yml vendored
@@ -65,12 +65,18 @@ jobs:
        run: |
          cd web
          npm run lit-analyse
- build:
+ ci-web-mark:
    needs:
      - lint-eslint
      - lint-prettier
      - lint-lit-analyse
    runs-on: ubuntu-latest
    steps:
+     - run: echo mark
+ build:
+   needs:
+     - ci-web-mark
+   runs-on: ubuntu-latest
+   steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-node@v2
27 .github/workflows/release-publish.yml vendored
@@ -30,14 +30,14 @@ jobs:
        with:
          push: ${{ github.event_name == 'release' }}
          tags: |
-           beryju/authentik:2021.12.1-rc1,
+           beryju/authentik:2021.12.1-rc4,
            beryju/authentik:latest,
-           ghcr.io/goauthentik/server:2021.12.1-rc1,
+           ghcr.io/goauthentik/server:2021.12.1-rc4,
            ghcr.io/goauthentik/server:latest
          platforms: linux/amd64,linux/arm64
          context: .
      - name: Building Docker Image (stable)
-       if: ${{ github.event_name == 'release' && !contains('2021.12.1-rc1', 'rc') }}
+       if: ${{ github.event_name == 'release' && !contains('2021.12.1-rc4', 'rc') }}
        run: |
          docker pull beryju/authentik:latest
          docker tag beryju/authentik:latest beryju/authentik:stable
@@ -78,14 +78,14 @@ jobs:
        with:
          push: ${{ github.event_name == 'release' }}
          tags: |
-           beryju/authentik-${{ matrix.type }}:2021.12.1-rc1,
+           beryju/authentik-${{ matrix.type }}:2021.12.1-rc4,
            beryju/authentik-${{ matrix.type }}:latest,
-           ghcr.io/goauthentik/${{ matrix.type }}:2021.12.1-rc1,
+           ghcr.io/goauthentik/${{ matrix.type }}:2021.12.1-rc4,
            ghcr.io/goauthentik/${{ matrix.type }}:latest
          file: ${{ matrix.type }}.Dockerfile
          platforms: linux/amd64,linux/arm64
      - name: Building Docker Image (stable)
-       if: ${{ github.event_name == 'release' && !contains('2021.12.1-rc1', 'rc') }}
+       if: ${{ github.event_name == 'release' && !contains('2021.12.1-rc4', 'rc') }}
        run: |
          docker pull beryju/authentik-${{ matrix.type }}:latest
          docker tag beryju/authentik-${{ matrix.type }}:latest beryju/authentik-${{ matrix.type }}:stable
@@ -114,16 +114,11 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
-     - name: Setup Node.js environment
-       uses: actions/setup-node@v2
-       with:
-         node-version: '16'
-     - name: Build web api client and web ui
+     - name: Get static files from docker image
        run: |
-         export NODE_ENV=production
-         cd web
-         npm i
-         npm run build
+         docker pull ghcr.io/goauthentik/server:latest
+         container=$(docker container create ghcr.io/goauthentik/server:latest)
+         docker cp ${container}:web/ .
      - name: Create a Sentry.io release
        uses: getsentry/action-release@v1
        if: ${{ github.event_name == 'release' }}
@@ -133,7 +128,7 @@ jobs:
          SENTRY_PROJECT: authentik
          SENTRY_URL: https://sentry.beryju.org
        with:
-         version: authentik@2021.12.1-rc1
+         version: authentik@2021.12.1-rc4
          environment: beryjuorg-prod
          sourcemaps: './web/dist'
          url_prefix: '~/static/dist'
19 Dockerfile
@@ -1,5 +1,5 @@
 # Stage 1: Lock python dependencies
-FROM docker.io/python:3.9-slim-bullseye as locker
+FROM docker.io/python:3.10.1-slim-bullseye as locker

 COPY ./Pipfile /app/
 COPY ./Pipfile.lock /app/
@@ -28,19 +28,15 @@ ENV NODE_ENV=production
 RUN cd /work/web && npm i && npm run build

 # Stage 4: Build go proxy
-FROM docker.io/golang:1.17.3-bullseye AS builder
+FROM docker.io/golang:1.17.5-bullseye AS builder

 WORKDIR /work

 COPY --from=web-builder /work/web/robots.txt /work/web/robots.txt
 COPY --from=web-builder /work/web/security.txt /work/web/security.txt
-COPY --from=web-builder /work/web/dist/ /work/web/dist/
-COPY --from=web-builder /work/web/authentik/ /work/web/authentik/
-COPY --from=website-builder /work/website/help/ /work/website/help/

 COPY ./cmd /work/cmd
 COPY ./web/static.go /work/web/static.go
 COPY ./website/static.go /work/website/static.go
 COPY ./internal /work/internal
 COPY ./go.mod /work/go.mod
 COPY ./go.sum /work/go.sum
@@ -48,7 +44,7 @@ COPY ./go.sum /work/go.sum
 RUN go build -o /work/authentik ./cmd/server/main.go

 # Stage 5: Run
-FROM docker.io/python:3.9-slim-bullseye
+FROM docker.io/python:3.10.1-slim-bullseye

 WORKDIR /
 COPY --from=locker /app/requirements.txt /
@@ -62,14 +58,16 @@ RUN apt-get update && \
     curl ca-certificates gnupg git runit libpq-dev \
     postgresql-client build-essential libxmlsec1-dev \
     pkg-config libmaxminddb0 && \
-    pip install lxml==4.6.4 --no-cache-dir && \
+    export C_INCLUDE_PATH=/usr/local/lib/python3.10/site-packages/lxml/includes && \
     pip install -r /requirements.txt --no-cache-dir && \
     apt-get remove --purge -y build-essential git && \
     apt-get autoremove --purge -y && \
     apt-get clean && \
     rm -rf /tmp/* /var/lib/apt/lists/* /var/tmp/ && \
     adduser --system --no-create-home --uid 1000 --group --home /authentik authentik && \
-    mkdir /backups && \
-    chown authentik:authentik /backups
+    mkdir -p /backups /certs /media && \
+    chown authentik:authentik /backups /certs /media

 COPY ./authentik/ /authentik
 COPY ./pyproject.toml /
@@ -78,6 +76,9 @@ COPY ./tests /tests
 COPY ./manage.py /
 COPY ./lifecycle/ /lifecycle
 COPY --from=builder /work/authentik /authentik-proxy
+COPY --from=web-builder /work/web/dist/ /web/dist/
+COPY --from=web-builder /work/web/authentik/ /web/authentik/
+COPY --from=website-builder /work/website/help/ /website/help/

 USER authentik
8 Makefile
@@ -68,7 +68,7 @@ gen-outpost:
	docker run \
		--rm -v ${PWD}:/local \
		--user ${UID}:${GID} \
-		openapitools/openapi-generator-cli generate \
+		openapitools/openapi-generator-cli:v5.2.1 generate \
		-i /local/schema.yml \
		-g go \
		-o /local/api \
@@ -84,6 +84,9 @@ migrate:
run:
	go run -v cmd/server/main.go

+web-watch:
+	cd web && npm run watch
+
web: web-lint-fix web-lint web-extract

web-lint-fix:
@@ -113,3 +116,6 @@ ci-bandit:

ci-pyright:
	pyright e2e lifecycle
+
+ci-pending-migrations:
+	./manage.py makemigrations --check
5 Pipfile
@@ -32,7 +32,8 @@ geoip2 = "*"
gunicorn = "*"
kubernetes = "==v19.15.0"
ldap3 = "*"
-lxml = "*"
+# 4.7.0 and later remove `lxml-version.h` which is required by xmlsec
+lxml = "==4.6.5"
packaging = "*"
psycopg2-binary = "*"
pycryptodome = "*"
@@ -49,6 +50,8 @@ urllib3 = {extras = ["secure"],version = "*"}
uvicorn = {extras = ["standard"],version = "*"}
webauthn = "*"
xmlsec = "*"
+flower = "*"
+wsproto = "*"

[dev-packages]
bandit = "*"
727 Pipfile.lock generated
File diff suppressed because it is too large
@@ -1,3 +1,3 @@
"""authentik"""
-__version__ = "2021.12.1-rc1"
+__version__ = "2021.12.1-rc4"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
@@ -54,7 +54,7 @@ def clear_update_notifications():


@CELERY_APP.task(bind=True, base=MonitoredTask)
-@prefill_task()
+@prefill_task
def update_latest_version(self: MonitoredTask):
    """Update latest version info"""
    if CONFIG.y_bool("disable_update_check"):
@@ -1,9 +1,11 @@
 """Groups API Viewset"""
+from json import loads

 from django.db.models.query import QuerySet
-from django_filters.filters import ModelMultipleChoiceFilter
+from django_filters.filters import CharFilter, ModelMultipleChoiceFilter
 from django_filters.filterset import FilterSet
 from rest_framework.fields import CharField, JSONField
-from rest_framework.serializers import ListSerializer, ModelSerializer
+from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError
 from rest_framework.viewsets import ModelViewSet
 from rest_framework_guardian.filters import ObjectPermissionsFilter

@@ -62,6 +64,13 @@ class GroupSerializer(ModelSerializer):
 class GroupFilter(FilterSet):
     """Filter for groups"""

+    attributes = CharFilter(
+        field_name="attributes",
+        lookup_expr="",
+        label="Attributes",
+        method="filter_attributes",
+    )
+
     members_by_username = ModelMultipleChoiceFilter(
         field_name="users__username",
         to_field_name="username",
@@ -72,10 +81,28 @@ class GroupFilter(FilterSet):
         queryset=User.objects.all(),
     )

+    # pylint: disable=unused-argument
+    def filter_attributes(self, queryset, name, value):
+        """Filter attributes by query args"""
+        try:
+            value = loads(value)
+        except ValueError:
+            raise ValidationError(detail="filter: failed to parse JSON")
+        if not isinstance(value, dict):
+            raise ValidationError(detail="filter: value must be key:value mapping")
+        qs = {}
+        for key, _value in value.items():
+            qs[f"attributes__{key}"] = _value
+        try:
+            _ = len(queryset.filter(**qs))
+            return queryset.filter(**qs)
+        except ValueError:
+            return queryset
+
     class Meta:

         model = Group
-        fields = ["name", "is_superuser", "members_by_pk", "members_by_username"]
+        fields = ["name", "is_superuser", "members_by_pk", "attributes", "members_by_username"]


 class GroupViewSet(UsedByMixin, ModelViewSet):
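The new `attributes` filter accepts a JSON object in the query string; `filter_attributes` parses it and expands each key into an `attributes__<key>` lookup. A minimal client-side sketch, assuming the `requests` library and placeholder host, token, and API base path (the path may differ between authentik versions):

```python
import json

import requests  # assumed HTTP client; any client works

AUTHENTIK_URL = "https://authentik.example.com"  # placeholder host
TOKEN = "<api-token>"  # placeholder token

# Find groups whose JSON attributes contain {"department": "it"}
response = requests.get(
    f"{AUTHENTIK_URL}/api/v3/core/groups/",
    params={"attributes": json.dumps({"department": "it"})},
    headers={"Authorization": f"Bearer {TOKEN}"},
)
response.raise_for_status()
for group in response.json()["results"]:
    print(group["name"])
```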
@@ -233,7 +233,11 @@ class UsersFilter(FilterSet):
         qs = {}
         for key, _value in value.items():
             qs[f"attributes__{key}"] = _value
-        return queryset.filter(**qs)
+        try:
+            _ = len(queryset.filter(**qs))
+            return queryset.filter(**qs)
+        except ValueError:
+            return queryset

     class Meta:
         model = User
@@ -314,7 +318,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
             name=username,
             attributes={USER_ATTRIBUTE_SA: True, USER_ATTRIBUTE_TOKEN_EXPIRING: False},
         )
-        if create_group:
+        if create_group and self.request.user.has_perm("authentik_core.add_group"):
             group = Group.objects.create(
                 name=username,
             )
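The same guard is applied to `UsersFilter` here: forcing evaluation with `len()` inside the `try` block makes a `ValueError` from an invalid JSON-field lookup surface immediately, so the unfiltered queryset can be returned instead of erroring later during serialization. The pattern in isolation, as a sketch assuming a Django model with a JSON `attributes` field:

```python
def filter_by_attributes(queryset, filters: dict):
    """Expand {"key": value} into attributes__key lookups; fall back to
    the unfiltered queryset if the lookup cannot be evaluated."""
    lookups = {f"attributes__{key}": value for key, value in filters.items()}
    try:
        candidate = queryset.filter(**lookups)
        len(candidate)  # querysets are lazy; force evaluation so errors raise here
        return candidate
    except ValueError:
        return queryset
```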
@@ -65,4 +65,6 @@ def structlog_add_request_id(logger: Logger, method_name: str, event_dict: dict):
    """If threadlocal has authentik defined, add request_id to log"""
    if hasattr(LOCAL, "authentik"):
        event_dict.update(LOCAL.authentik)
+   if hasattr(LOCAL, "authentik_task"):
+       event_dict.update(LOCAL.authentik_task)
    return event_dict
@@ -25,7 +25,6 @@ from structlog.stdlib import get_logger
 from authentik.core.exceptions import PropertyMappingExpressionException
 from authentik.core.signals import password_changed
 from authentik.core.types import UILoginButton, UserSettingSerializer
-from authentik.flows.models import Flow
 from authentik.lib.config import CONFIG
 from authentik.lib.generators import generate_id
 from authentik.lib.models import CreatedUpdatedModel, DomainlessURLValidator, SerializerModel
@@ -203,7 +202,7 @@ class Provider(SerializerModel):
     name = models.TextField()

     authorization_flow = models.ForeignKey(
-        Flow,
+        "authentik_flows.Flow",
         on_delete=models.CASCADE,
         help_text=_("Flow used when authorizing this provider."),
         related_name="provider_authorization",
@@ -263,7 +262,7 @@ class Application(PolicyBindingModel):
     it is returned as-is"""
     if not self.meta_icon:
         return None
-    if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith("/static"):
+    if "://" in self.meta_icon.name or self.meta_icon.name.startswith("/static"):
         return self.meta_icon.name
     return self.meta_icon.url

@@ -324,7 +323,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
     property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)

     authentication_flow = models.ForeignKey(
-        Flow,
+        "authentik_flows.Flow",
         blank=True,
         null=True,
         default=None,
@@ -333,7 +332,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
         related_name="source_authentication",
     )
     enrollment_flow = models.ForeignKey(
-        Flow,
+        "authentik_flows.Flow",
         blank=True,
         null=True,
         default=None,
@@ -29,7 +29,7 @@ LOGGER = get_logger()


@CELERY_APP.task(bind=True, base=MonitoredTask)
-@prefill_task()
+@prefill_task
def clean_expired_models(self: MonitoredTask):
    """Remove expired objects"""
    messages = []
@@ -69,7 +69,7 @@ def should_backup() -> bool:


@CELERY_APP.task(bind=True, base=MonitoredTask)
-@prefill_task()
+@prefill_task
def backup_database(self: MonitoredTask):  # pragma: no cover
    """Database backup"""
    self.result_timeout_hours = 25
@@ -20,6 +20,7 @@ from authentik.api.decorators import permission_required
 from authentik.core.api.used_by import UsedByMixin
 from authentik.core.api.utils import PassiveSerializer
 from authentik.crypto.builder import CertificateBuilder
+from authentik.crypto.managed import MANAGED_KEY
 from authentik.crypto.models import CertificateKeyPair
 from authentik.events.models import Event, EventAction

@@ -141,9 +142,11 @@ class CertificateKeyPairFilter(FilterSet):
 class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
     """CertificateKeyPair Viewset"""

-    queryset = CertificateKeyPair.objects.exclude(managed__isnull=False)
+    queryset = CertificateKeyPair.objects.exclude(managed=MANAGED_KEY)
     serializer_class = CertificateKeyPairSerializer
     filterset_class = CertificateKeyPairFilter
+    ordering = ["name"]
+    search_fields = ["name"]

     @permission_required(None, ["authentik_crypto.add_certificatekeypair"])
     @extend_schema(
@@ -189,7 +192,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
         secret=certificate,
         type="certificate",
     ).from_http(request)
-    if "download" in request._request.GET:
+    if "download" in request.query_params:
         # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html
         response = HttpResponse(
             certificate.certificate_data, content_type="application/x-pem-file"
@@ -220,7 +223,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
         secret=certificate,
         type="private_key",
     ).from_http(request)
-    if "download" in request._request.GET:
+    if "download" in request.query_params:
         # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html
         response = HttpResponse(certificate.key_data, content_type="application/x-pem-file")
         response[
@@ -13,3 +13,4 @@ class AuthentikCryptoConfig(AppConfig):

    def ready(self):
        import_module("authentik.crypto.managed")
+       import_module("authentik.crypto.tasks")
10 authentik/crypto/settings.py Normal file
@@ -0,0 +1,10 @@
+"""Crypto task Settings"""
+from celery.schedules import crontab
+
+CELERY_BEAT_SCHEDULE = {
+    "crypto_certificate_discovery": {
+        "task": "authentik.crypto.tasks.certificate_discovery",
+        "schedule": crontab(minute="*/5"),
+        "options": {"queue": "authentik_scheduled"},
+    },
+}
73 authentik/crypto/tasks.py Normal file
@@ -0,0 +1,73 @@
+"""Crypto tasks"""
+from glob import glob
+from pathlib import Path
+
+from django.utils.translation import gettext_lazy as _
+from structlog.stdlib import get_logger
+
+from authentik.crypto.models import CertificateKeyPair
+from authentik.events.monitored_tasks import (
+    MonitoredTask,
+    TaskResult,
+    TaskResultStatus,
+    prefill_task,
+)
+from authentik.lib.config import CONFIG
+from authentik.root.celery import CELERY_APP
+
+LOGGER = get_logger()
+
+MANAGED_DISCOVERED = "goauthentik.io/crypto/discovered/%s"
+
+
+@CELERY_APP.task(bind=True, base=MonitoredTask)
+@prefill_task
+def certificate_discovery(self: MonitoredTask):
+    """Discover and update certificates from the filesystem"""
+    certs = {}
+    private_keys = {}
+    discovered = 0
+    for file in glob(CONFIG.y("cert_discovery_dir") + "/**", recursive=True):
+        path = Path(file)
+        if not path.exists():
+            continue
+        if path.is_dir():
+            continue
+        # Support certbot's directory structure
+        if path.name in ["fullchain.pem", "privkey.pem"]:
+            cert_name = path.parent.name
+        else:
+            cert_name = path.name.replace(path.suffix, "")
+        try:
+            with open(path, "r+", encoding="utf-8") as _file:
+                body = _file.read()
+                if "BEGIN RSA PRIVATE KEY" in body:
+                    private_keys[cert_name] = body
+                else:
+                    certs[cert_name] = body
+        except OSError as exc:
+            LOGGER.warning("Failed to open file", exc=exc, file=path)
+        discovered += 1
+    for name, cert_data in certs.items():
+        cert = CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % name).first()
+        if not cert:
+            cert = CertificateKeyPair(
+                name=name,
+                managed=MANAGED_DISCOVERED % name,
+            )
+        dirty = False
+        if cert.certificate_data != cert_data:
+            cert.certificate_data = cert_data
+            dirty = True
+        if name in private_keys:
+            if cert.key_data == private_keys[name]:
+                cert.key_data = private_keys[name]
+                dirty = True
+        if dirty:
+            cert.save()
+    self.set_status(
+        TaskResult(
+            TaskResultStatus.SUCCESSFUL,
+            messages=[_("Successfully imported %(count)d files." % {"count": discovered})],
+        )
+    )
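`certificate_discovery` walks `cert_discovery_dir` recursively every five minutes (per the new beat schedule above) and keys discovered pairs by file stem or, for certbot layouts, by directory name. A sketch of a layout it would import; the file and directory names here are hypothetical:

```python
# /certs is the default cert_discovery_dir (see the default.yml change below).
#
#   /certs/my-cert.pem      -> keypair "my-cert" (certificate body)
#   /certs/my-cert.key      -> matched to "my-cert"; classified as a private
#                              key by its "BEGIN RSA PRIVATE KEY" header
#   /certs/example.org/     -> certbot layout: the directory name becomes
#       fullchain.pem          the keypair name "example.org"
#       privkey.pem
#
# Each import is stored as a CertificateKeyPair with
# managed = "goauthentik.io/crypto/discovered/<name>".
```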
@@ -1,5 +1,7 @@
 """Crypto tests"""
 import datetime
+from os import makedirs
+from tempfile import TemporaryDirectory

 from django.urls import reverse
 from rest_framework.test import APITestCase
@@ -9,6 +11,8 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_cert,
 from authentik.crypto.api import CertificateKeyPairSerializer
 from authentik.crypto.builder import CertificateBuilder
 from authentik.crypto.models import CertificateKeyPair
+from authentik.crypto.tasks import MANAGED_DISCOVERED, certificate_discovery
+from authentik.lib.config import CONFIG
 from authentik.lib.generators import generate_key
 from authentik.providers.oauth2.models import OAuth2Provider

@@ -163,3 +167,33 @@ class TestCrypto(APITestCase):
                 }
             ],
         )
+
+    def test_discovery(self):
+        """Test certificate discovery"""
+        builder = CertificateBuilder()
+        builder.common_name = "test-cert"
+        with self.assertRaises(ValueError):
+            builder.save()
+        builder.build(
+            subject_alt_names=[],
+            validity_days=3,
+        )
+        with TemporaryDirectory() as temp_dir:
+            with open(f"{temp_dir}/foo.pem", "w+", encoding="utf-8") as _cert:
+                _cert.write(builder.certificate)
+            with open(f"{temp_dir}/foo.key", "w+", encoding="utf-8") as _key:
+                _key.write(builder.private_key)
+            makedirs(f"{temp_dir}/foo.bar", exist_ok=True)
+            with open(f"{temp_dir}/foo.bar/fullchain.pem", "w+", encoding="utf-8") as _cert:
+                _cert.write(builder.certificate)
+            with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key:
+                _key.write(builder.private_key)
+            with CONFIG.patch("cert_discovery_dir", temp_dir):
+                # pyright: reportGeneralTypeIssues=false
+                certificate_discovery()  # pylint: disable=no-value-for-parameter
+                self.assertTrue(
+                    CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % "foo").exists()
+                )
+                self.assertTrue(
+                    CertificateKeyPair.objects.filter(managed=MANAGED_DISCOVERED % "foo.bar").exists()
+                )
@@ -112,30 +112,6 @@ class TaskInfo:
         cache.set(key, self, timeout=timeout_hours * 60 * 60)


-def prefill_task():
-    """Ensure a task's details are always in cache, so it can always be triggered via API"""
-
-    def inner_wrap(func):
-        status = TaskInfo.by_name(func.__name__)
-        if status:
-            return func
-        TaskInfo(
-            task_name=func.__name__,
-            task_description=func.__doc__,
-            result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]),
-            task_call_module=func.__module__,
-            task_call_func=func.__name__,
-            # We don't have real values for these attributes but they cannot be null
-            start_timestamp=default_timer(),
-            finish_timestamp=default_timer(),
-            finish_time=datetime.now(),
-        ).save(86400)
-        LOGGER.debug("prefilled task", task_name=func.__name__)
-        return func
-
-    return inner_wrap
-
-
 class MonitoredTask(Task):
     """Task which can save its state to the cache"""

@@ -210,5 +186,21 @@ class MonitoredTask(Task):
     raise NotImplementedError


-for task in TaskInfo.all().values():
-    task.set_prom_metrics()
+def prefill_task(func):
+    """Ensure a task's details are always in cache, so it can always be triggered via API"""
+    status = TaskInfo.by_name(func.__name__)
+    if status:
+        return func
+    TaskInfo(
+        task_name=func.__name__,
+        task_description=func.__doc__,
+        result=TaskResult(TaskResultStatus.UNKNOWN, messages=[_("Task has not been run yet.")]),
+        task_call_module=func.__module__,
+        task_call_func=func.__name__,
+        # We don't have real values for these attributes but they cannot be null
+        start_timestamp=default_timer(),
+        finish_timestamp=default_timer(),
+        finish_time=datetime.now(),
+    ).save(86400)
+    LOGGER.debug("prefilled task", task_name=func.__name__)
+    return func
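Since `prefill_task` is now a plain decorator rather than a decorator factory, every call site in this compare drops the parentheses. A sketch of the before/after, with `example_task` as a placeholder name:

```python
# Before this change: decorator factory, applied with parentheses
@CELERY_APP.task(bind=True, base=MonitoredTask)
@prefill_task()
def example_task(self: MonitoredTask):
    """Example task"""

# After this change: plain decorator, applied directly to the function
@CELERY_APP.task(bind=True, base=MonitoredTask)
@prefill_task
def example_task(self: MonitoredTask):
    """Example task"""
```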
46 authentik/flows/migrations/0020_flowtoken.py Normal file
@@ -0,0 +1,46 @@
+# Generated by Django 3.2.9 on 2021-12-05 13:50
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("authentik_core", "0018_auto_20210330_1345_squashed_0028_alter_token_intent"),
+        (
+            "authentik_flows",
+            "0019_alter_flow_background_squashed_0024_alter_flow_compatibility_mode",
+        ),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="FlowToken",
+            fields=[
+                (
+                    "token_ptr",
+                    models.OneToOneField(
+                        auto_created=True,
+                        on_delete=django.db.models.deletion.CASCADE,
+                        parent_link=True,
+                        primary_key=True,
+                        serialize=False,
+                        to="authentik_core.token",
+                    ),
+                ),
+                ("_plan", models.TextField()),
+                (
+                    "flow",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE, to="authentik_flows.flow"
+                    ),
+                ),
+            ],
+            options={
+                "verbose_name": "Flow Token",
+                "verbose_name_plural": "Flow Tokens",
+            },
+            bases=("authentik_core.token",),
+        ),
+    ]
@@ -1,4 +1,6 @@
 """Flow models"""
+from base64 import b64decode, b64encode
+from pickle import dumps, loads  # nosec
 from typing import TYPE_CHECKING, Optional, Type
 from uuid import uuid4

@@ -9,11 +11,13 @@ from model_utils.managers import InheritanceManager
 from rest_framework.serializers import BaseSerializer
 from structlog.stdlib import get_logger

+from authentik.core.models import Token
 from authentik.core.types import UserSettingSerializer
 from authentik.lib.models import InheritanceForeignKey, SerializerModel
 from authentik.policies.models import PolicyBindingModel

 if TYPE_CHECKING:
+    from authentik.flows.planner import FlowPlan
     from authentik.flows.stage import StageView

 LOGGER = get_logger()
@@ -260,3 +264,30 @@ class ConfigurableStage(models.Model):
     class Meta:

         abstract = True
+
+
+class FlowToken(Token):
+    """Subclass of a standard Token, stores the currently active flow plan upon creation.
+    Can be used to later resume a flow."""
+
+    flow = models.ForeignKey(Flow, on_delete=models.CASCADE)
+    _plan = models.TextField()
+
+    @staticmethod
+    def pickle(plan) -> str:
+        """Pickle into string"""
+        data = dumps(plan)
+        return b64encode(data).decode()
+
+    @property
+    def plan(self) -> "FlowPlan":
+        """Load Flow plan from pickled version"""
+        return loads(b64decode(self._plan.encode()))  # nosec
+
+    def __str__(self) -> str:
+        return f"Flow Token {super().__str__()}"
+
+    class Meta:
+
+        verbose_name = _("Flow Token")
+        verbose_name_plural = _("Flow Tokens")
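A rough sketch of the FlowToken round trip, using the names from the diff above; the `user` and `identifier` fields are inherited from the base Token model, and the values here are placeholders:

```python
# Persist the currently active plan so the flow can be resumed later
token = FlowToken.objects.create(
    user=pending_user,             # placeholder: the user the plan belongs to
    identifier="example-restore",  # placeholder identifier
    flow=flow,                     # the Flow the plan was built for
    _plan=FlowToken.pickle(plan),  # base64-encoded pickle of the FlowPlan
)

# Later: the property unpickles the stored plan; the executor then marks
# the restored plan with context["is_restored"] = True
restored_plan = token.plan
```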
@@ -24,6 +24,9 @@ PLAN_CONTEXT_SSO = "is_sso"
 PLAN_CONTEXT_REDIRECT = "redirect"
 PLAN_CONTEXT_APPLICATION = "application"
 PLAN_CONTEXT_SOURCE = "source"
+# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
+# was restored.
+PLAN_CONTEXT_IS_RESTORED = "is_restored"
 GAUGE_FLOWS_CACHED = UpdatingGauge(
     "authentik_flows_cached",
     "Cached flows",
@@ -123,7 +126,7 @@ class FlowPlanner:
     ) -> FlowPlan:
         """Check each of the flows' policies, check policies for each stage with PolicyBinding
         and return ordered list"""
-        with Hub.current.start_span(op="flow.planner.plan") as span:
+        with Hub.current.start_span(op="flow.planner.plan", description=self.flow.slug) as span:
             span: Span
             span.set_data("flow", self.flow)
             span.set_data("request", request)
@@ -178,7 +181,8 @@ class FlowPlanner:
         """Build flow plan by checking each stage in their respective
         order and checking the applied policies"""
         with Hub.current.start_span(
-            op="flow.planner.build_plan"
+            op="flow.planner.build_plan",
+            description=self.flow.slug,
         ) as span, HIST_FLOWS_PLAN_TIME.labels(flow_slug=self.flow.slug).time():
             span: Span
             span.set_data("flow", self.flow)
@@ -19,6 +19,8 @@ from drf_spectacular.utils import OpenApiParameter, PolymorphicProxySerializer,
 from rest_framework.permissions import AllowAny
 from rest_framework.views import APIView
 from sentry_sdk import capture_exception
+from sentry_sdk.api import set_tag
+from sentry_sdk.hub import Hub
 from structlog.stdlib import BoundLogger, get_logger

 from authentik.core.models import USER_ATTRIBUTE_DEBUG
@@ -34,8 +36,16 @@ from authentik.flows.challenge import (
     WithUserInfoChallenge,
 )
 from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
-from authentik.flows.models import ConfigurableStage, Flow, FlowDesignation, FlowStageBinding, Stage
+from authentik.flows.models import (
+    ConfigurableStage,
+    Flow,
+    FlowDesignation,
+    FlowStageBinding,
+    FlowToken,
+    Stage,
+)
 from authentik.flows.planner import (
+    PLAN_CONTEXT_IS_RESTORED,
     PLAN_CONTEXT_PENDING_USER,
     PLAN_CONTEXT_REDIRECT,
     FlowPlan,
@@ -53,7 +63,9 @@ NEXT_ARG_NAME = "next"
 SESSION_KEY_PLAN = "authentik_flows_plan"
 SESSION_KEY_APPLICATION_PRE = "authentik_flows_application_pre"
 SESSION_KEY_GET = "authentik_flows_get"
+SESSION_KEY_POST = "authentik_flows_post"
 SESSION_KEY_HISTORY = "authentik_flows_history"
+QS_KEY_TOKEN = "flow_token"  # nosec


 def challenge_types():
@@ -116,6 +128,7 @@ class FlowExecutorView(APIView):
         super().setup(request, flow_slug=flow_slug)
         self.flow = get_object_or_404(Flow.objects.select_related(), slug=flow_slug)
         self._logger = get_logger().bind(flow_slug=flow_slug)
+        set_tag("authentik.flow", self.flow.slug)

     def handle_invalid_flow(self, exc: BaseException) -> HttpResponse:
         """When a flow is non-applicable check if user is on the correct domain"""
@@ -126,71 +139,100 @@ class FlowExecutorView(APIView):
         message = exc.__doc__ if exc.__doc__ else str(exc)
         return self.stage_invalid(error_message=message)

+    def _check_flow_token(self, get_params: QueryDict):
+        """Check if the user is using a flow token to restore a plan"""
+        tokens = FlowToken.filter_not_expired(key=get_params[QS_KEY_TOKEN])
+        if not tokens.exists():
+            return False
+        token: FlowToken = tokens.first()
+        try:
+            plan = token.plan
+        except (AttributeError, EOFError, ImportError, IndexError) as exc:
+            LOGGER.warning("f(exec): Failed to restore token plan", exc=exc)
+        finally:
+            token.delete()
+        if not isinstance(plan, FlowPlan):
+            return None
+        plan.context[PLAN_CONTEXT_IS_RESTORED] = True
+        self._logger.debug("f(exec): restored flow plan from token", plan=plan)
+        return plan
+
     # pylint: disable=unused-argument, too-many-return-statements
     def dispatch(self, request: HttpRequest, flow_slug: str) -> HttpResponse:
-        # Early check if there's an active Plan for the current session
-        if SESSION_KEY_PLAN in self.request.session:
-            self.plan = self.request.session[SESSION_KEY_PLAN]
-            if self.plan.flow_pk != self.flow.pk.hex:
-                self._logger.warning(
-                    "f(exec): Found existing plan for other flow, deleting plan",
-                )
-                # Existing plan is deleted from session and instance
-                self.plan = None
-                self.cancel()
-            self._logger.debug("f(exec): Continuing existing plan")
-
-        # Don't check session again as we've either already loaded the plan or we need to plan
-        if not self.plan:
-            request.session[SESSION_KEY_HISTORY] = []
-            self._logger.debug("f(exec): No active Plan found, initiating planner")
-            try:
-                self.plan = self._initiate_plan()
-            except FlowNonApplicableException as exc:
-                self._logger.warning("f(exec): Flow not applicable to current user", exc=exc)
-                return to_stage_response(self.request, self.handle_invalid_flow(exc))
-            except EmptyFlowException as exc:
-                self._logger.warning("f(exec): Flow is empty", exc=exc)
-                # To match behaviour with loading an empty flow plan from cache,
-                # we don't show an error message here, but rather call _flow_done()
-                return self._flow_done()
-        # Initial flow request, check if we have an upstream query string passed in
-        request.session[SESSION_KEY_GET] = QueryDict(request.GET.get("query", ""))
-        # We don't save the Plan after getting the next stage
-        # as it hasn't been successfully passed yet
-        try:
-            # This is the first time we actually access any attribute on the selected plan
-            # if the cached plan is from an older version, it might have different attributes
-            # in which case we just delete the plan and invalidate everything
-            next_binding = self.plan.next(self.request)
-        except Exception as exc:  # pylint: disable=broad-except
-            self._logger.warning("f(exec): found incompatible flow plan, invalidating run", exc=exc)
-            keys = cache.keys("flow_*")
-            cache.delete_many(keys)
-            return self.stage_invalid()
-        if not next_binding:
-            self._logger.debug("f(exec): no more stages, flow is done.")
-            return self._flow_done()
-        self.current_binding = next_binding
-        self.current_stage = next_binding.stage
-        self._logger.debug(
-            "f(exec): Current stage",
-            current_stage=self.current_stage,
-            flow_slug=self.flow.slug,
-        )
-        try:
-            stage_cls = self.current_stage.type
-        except NotImplementedError as exc:
-            self._logger.debug("Error getting stage type", exc=exc)
-            return self.stage_invalid()
-        self.current_stage_view = stage_cls(self)
-        self.current_stage_view.args = self.args
-        self.current_stage_view.kwargs = self.kwargs
-        self.current_stage_view.request = request
-        try:
-            return super().dispatch(request)
-        except InvalidStageError as exc:
-            return self.stage_invalid(str(exc))
+        with Hub.current.start_span(
+            op="flow.executor.dispatch", description=self.flow.slug
+        ) as span:
+            span.set_data("authentik Flow", self.flow.slug)
+            get_params = QueryDict(request.GET.get("query", ""))
+            if QS_KEY_TOKEN in get_params:
+                plan = self._check_flow_token(get_params)
+                if plan:
+                    self.request.session[SESSION_KEY_PLAN] = plan
+            # Early check if there's an active Plan for the current session
+            if SESSION_KEY_PLAN in self.request.session:
+                self.plan = self.request.session[SESSION_KEY_PLAN]
+                if self.plan.flow_pk != self.flow.pk.hex:
+                    self._logger.warning(
+                        "f(exec): Found existing plan for other flow, deleting plan",
+                    )
+                    # Existing plan is deleted from session and instance
+                    self.plan = None
+                    self.cancel()
+                self._logger.debug("f(exec): Continuing existing plan")
+
+            # Don't check session again as we've either already loaded the plan or we need to plan
+            if not self.plan:
+                request.session[SESSION_KEY_HISTORY] = []
+                self._logger.debug("f(exec): No active Plan found, initiating planner")
+                try:
+                    self.plan = self._initiate_plan()
+                except FlowNonApplicableException as exc:
+                    self._logger.warning("f(exec): Flow not applicable to current user", exc=exc)
+                    return to_stage_response(self.request, self.handle_invalid_flow(exc))
+                except EmptyFlowException as exc:
+                    self._logger.warning("f(exec): Flow is empty", exc=exc)
+                    # To match behaviour with loading an empty flow plan from cache,
+                    # we don't show an error message here, but rather call _flow_done()
+                    return self._flow_done()
+            # Initial flow request, check if we have an upstream query string passed in
+            request.session[SESSION_KEY_GET] = get_params
+            # We don't save the Plan after getting the next stage
+            # as it hasn't been successfully passed yet
+            try:
+                # This is the first time we actually access any attribute on the selected plan
+                # if the cached plan is from an older version, it might have different attributes
+                # in which case we just delete the plan and invalidate everything
+                next_binding = self.plan.next(self.request)
+            except Exception as exc:  # pylint: disable=broad-except
+                self._logger.warning(
+                    "f(exec): found incompatible flow plan, invalidating run", exc=exc
+                )
+                keys = cache.keys("flow_*")
+                cache.delete_many(keys)
+                return self.stage_invalid()
+            if not next_binding:
+                self._logger.debug("f(exec): no more stages, flow is done.")
+                return self._flow_done()
+            self.current_binding = next_binding
+            self.current_stage = next_binding.stage
+            self._logger.debug(
+                "f(exec): Current stage",
+                current_stage=self.current_stage,
+                flow_slug=self.flow.slug,
+            )
+            try:
+                stage_cls = self.current_stage.type
+            except NotImplementedError as exc:
+                self._logger.debug("Error getting stage type", exc=exc)
+                return self.stage_invalid()
+            self.current_stage_view = stage_cls(self)
+            self.current_stage_view.args = self.args
+            self.current_stage_view.kwargs = self.kwargs
+            self.current_stage_view.request = request
+            try:
+                return super().dispatch(request)
+            except InvalidStageError as exc:
+                return self.stage_invalid(str(exc))

     def handle_exception(self, exc: Exception) -> HttpResponse:
         """Handle exception in stage execution"""
@@ -232,8 +274,15 @@ class FlowExecutorView(APIView):
             stage=self.current_stage,
         )
         try:
-            stage_response = self.current_stage_view.get(request, *args, **kwargs)
-            return to_stage_response(request, stage_response)
+            with Hub.current.start_span(
+                op="flow.executor.stage",
+                description=class_to_path(self.current_stage_view.__class__),
+            ) as span:
+                span.set_data("Method", "GET")
+                span.set_data("authentik Stage", self.current_stage_view)
+                span.set_data("authentik Flow", self.flow.slug)
+                stage_response = self.current_stage_view.get(request, *args, **kwargs)
+                return to_stage_response(request, stage_response)
         except Exception as exc:  # pylint: disable=broad-except
             return self.handle_exception(exc)

@@ -269,8 +318,15 @@ class FlowExecutorView(APIView):
             stage=self.current_stage,
         )
         try:
-            stage_response = self.current_stage_view.post(request, *args, **kwargs)
-            return to_stage_response(request, stage_response)
+            with Hub.current.start_span(
+                op="flow.executor.stage",
+                description=class_to_path(self.current_stage_view.__class__),
+            ) as span:
+                span.set_data("Method", "POST")
+                span.set_data("authentik Stage", self.current_stage_view)
+                span.set_data("authentik Flow", self.flow.slug)
+                stage_response = self.current_stage_view.post(request, *args, **kwargs)
+                return to_stage_response(request, stage_response)
         except Exception as exc:  # pylint: disable=broad-except
             return self.handle_exception(exc)
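Because the executor parses the upstream query string out of its own `query` parameter, a resume link carries `flow_token` (the new `QS_KEY_TOKEN`) nested inside `query`. A sketch of building such a URL; the host, flow slug, and API path are placeholders:

```python
from urllib.parse import urlencode

host = "https://authentik.example.com"  # placeholder
flow_slug = "default-enrollment-flow"   # placeholder
inner = urlencode({"flow_token": "<token-key>"})

# The executor reads request.GET["query"] and parses it as a QueryDict,
# so flow_token travels inside the outer "query" parameter.
resume_url = f"{host}/api/v3/flows/executor/{flow_slug}/?" + urlencode({"query": inner})
```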
@@ -87,9 +87,7 @@ class FlowInspectorView(APIView):
    @extend_schema(
        responses={
            200: FlowInspectionSerializer(),
-           400: OpenApiResponse(
-               description="No flow plan in session."
-           ),  # This error can be raised by the email stage
+           400: OpenApiResponse(description="No flow plan in session."),
        },
        request=OpenApiTypes.NONE,
        operation_id="flows_inspector_get",
@@ -106,7 +104,10 @@ class FlowInspectorView(APIView):
        if SESSION_KEY_PLAN in request.session:
            current_plan: FlowPlan = request.session[SESSION_KEY_PLAN]
        else:
-           current_plan = request.session[SESSION_KEY_HISTORY][-1]
+           try:
+               current_plan = request.session[SESSION_KEY_HISTORY][-1]
+           except KeyError:
+               return Response(status=400)
        is_completed = True
        current_serializer = FlowInspectorPlanSerializer(
            instance=current_plan, context={"request": request}
@@ -20,7 +20,6 @@ web:
  listen: 0.0.0.0:9000
  listen_tls: 0.0.0.0:9443
  listen_metrics: 0.0.0.0:9300
  load_local_files: false
- outpost_port_offset: 0

redis:
@@ -47,6 +46,7 @@ error_reporting:
  enabled: false
  environment: customer
  send_pii: false
+ sample_rate: 0.5

# Global email settings
email:
@@ -82,3 +82,4 @@ default_user_change_email: true
default_user_change_username: true

gdpr_compliance: true
+cert_discovery_dir: /certs
@@ -68,9 +68,9 @@ class DomainlessURLValidator(URLValidator):
        )
        self.schemes = ["http", "https", "blank"] + list(self.schemes)

-   def __call__(self, value):
+   def __call__(self, value: str):
        # Check if the scheme is valid.
        scheme = value.split("://")[0].lower()
        if scheme not in self.schemes:
            value = "default" + value
-       return super().__call__(value)
+       super().__call__(value)
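Dropping the `return` matches Django's validator contract: validators report success by returning nothing and failure by raising. A usage sketch:

```python
from django.core.exceptions import ValidationError

from authentik.lib.models import DomainlessURLValidator

validator = DomainlessURLValidator(schemes=["http", "https"])
try:
    validator("https://example.com/callback")  # valid: no exception
    validator("not a url")                     # invalid: raises ValidationError
except ValidationError as exc:
    print(exc.messages)
```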
@@ -4,6 +4,7 @@ from typing import Any, Optional

 from django.http import HttpRequest
 from requests.sessions import Session
+from sentry_sdk.hub import Hub
 from structlog.stdlib import get_logger

 from authentik import ENV_GIT_HASH_KEY, __version__
@@ -52,6 +53,12 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
             fake_ip=fake_ip,
         )
         return None
+    # Update sentry scope to include correct IP
+    user = Hub.current.scope._user
+    if not user:
+        user = {}
+    user["ip_address"] = fake_ip
+    Hub.current.scope.set_user(user)
     return fake_ip
@@ -12,7 +12,7 @@ from authentik.managed.manager import ObjectManager


@CELERY_APP.task(bind=True, base=MonitoredTask)
-@prefill_task()
+@prefill_task
def managed_reconcile(self: MonitoredTask):
    """Run ObjectManager to ensure objects are up-to-date"""
    try:
@@ -1,6 +1,8 @@
 """Outpost API Views"""
 from dacite.core import from_dict
 from dacite.exceptions import DaciteError
+from django_filters.filters import ModelMultipleChoiceFilter
+from django_filters.filterset import FilterSet
 from drf_spectacular.utils import extend_schema
 from rest_framework.decorators import action
 from rest_framework.fields import BooleanField, CharField, DateTimeField
@@ -99,16 +101,30 @@ class OutpostHealthSerializer(PassiveSerializer):
     version_outdated = BooleanField(read_only=True)


+class OutpostFilter(FilterSet):
+    """Filter for Outposts"""
+
+    providers_by_pk = ModelMultipleChoiceFilter(
+        field_name="providers",
+        queryset=Provider.objects.all(),
+    )
+
+    class Meta:
+
+        model = Outpost
+        fields = {
+            "providers": ["isnull"],
+            "name": ["iexact", "icontains"],
+            "service_connection__name": ["iexact", "icontains"],
+        }
+
+
 class OutpostViewSet(UsedByMixin, ModelViewSet):
     """Outpost Viewset"""

     queryset = Outpost.objects.all()
     serializer_class = OutpostSerializer
-    filterset_fields = {
-        "providers": ["isnull"],
-        "name": ["iexact", "icontains"],
-        "service_connection__name": ["iexact", "icontains"],
-    }
+    filterset_class = OutpostFilter
     search_fields = [
         "name",
         "providers__name",
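Moving from `filterset_fields` to an explicit `OutpostFilter` keeps the existing lookup fields and adds `providers_by_pk`. A client-side sketch with the same placeholder host, token, and API base path as the group-filter example above; the pk value is also a placeholder:

```python
import requests  # assumed HTTP client

AUTHENTIK_URL = "https://authentik.example.com"  # placeholder host
TOKEN = "<api-token>"  # placeholder token

# List outposts serving a given provider primary key
response = requests.get(
    f"{AUTHENTIK_URL}/api/v3/outposts/instances/",
    params={"providers_by_pk": 42},
    headers={"Authorization": f"Bearer {TOKEN}"},
)
response.raise_for_status()
```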
@@ -19,8 +19,9 @@ class AuthentikOutpostConfig(AppConfig):
        import_module("authentik.outposts.signals")
        import_module("authentik.outposts.managed")
        try:
-           from authentik.outposts.tasks import outpost_local_connection
+           from authentik.outposts.tasks import outpost_controller_all, outpost_local_connection

            outpost_local_connection.delay()
+           outpost_controller_all.delay()
        except ProgrammingError:
            pass
@@ -9,7 +9,7 @@ from dacite import from_dict
 from dacite.data import Data
 from guardian.shortcuts import get_objects_for_user
 from prometheus_client import Gauge
-from structlog.stdlib import get_logger
+from structlog.stdlib import BoundLogger, get_logger

 from authentik.core.channels import AuthJsonConsumer
 from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
@@ -23,8 +23,6 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
     ["outpost", "uid", "version"],
 )

-LOGGER = get_logger()
-

 class WebsocketMessageInstruction(IntEnum):
     """Commands which can be triggered over Websocket"""

@@ -51,6 +49,7 @@ class OutpostConsumer(AuthJsonConsumer):
     """Handler for Outposts that connect over websockets for health checks and live updates"""

     outpost: Optional[Outpost] = None
+    logger: BoundLogger

     last_uid: Optional[str] = None

@@ -59,11 +58,20 @@ class OutpostConsumer(AuthJsonConsumer):
     def connect(self):
         super().connect()
         uuid = self.scope["url_route"]["kwargs"]["pk"]
-        outpost = get_objects_for_user(self.user, "authentik_outposts.view_outpost").filter(pk=uuid)
-        if not outpost.exists():
+        outpost = (
+            get_objects_for_user(self.user, "authentik_outposts.view_outpost")
+            .filter(pk=uuid)
+            .first()
+        )
+        if not outpost:
             raise DenyConnection()
-        self.accept()
-        self.outpost = outpost.first()
+        self.logger = get_logger().bind(outpost=outpost)
+        try:
+            self.accept()
+        except RuntimeError as exc:
+            self.logger.warning("runtime error during accept", exc=exc)
+            raise DenyConnection()
+        self.outpost = outpost
         self.last_uid = self.channel_name

     # pylint: disable=unused-argument
@@ -78,9 +86,8 @@ class OutpostConsumer(AuthJsonConsumer):
             uid=self.last_uid,
             expected=self.outpost.config.kubernetes_replicas,
         ).dec()
-        LOGGER.debug(
+        self.logger.debug(
             "removed outpost instance from cache",
-            outpost=self.outpost,
             instance_uuid=self.last_uid,
         )

@@ -103,9 +110,8 @@ class OutpostConsumer(AuthJsonConsumer):
             uid=self.last_uid,
             expected=self.outpost.config.kubernetes_replicas,
         ).inc()
-        LOGGER.debug(
+        self.logger.debug(
             "added outpost instance to cache",
-            outpost=self.outpost,
             instance_uuid=self.last_uid,
         )
         self.first_msg = True
@@ -24,6 +24,8 @@ class DockerController(BaseController):

    def __init__(self, outpost: Outpost, connection: DockerServiceConnection) -> None:
        super().__init__(outpost, connection)
+       if outpost.managed == MANAGED_OUTPOST:
+           return
        try:
            self.client = connection.client()
        except ServiceConnectionInvalid as exc:
@@ -225,12 +227,14 @@ class DockerController(BaseController):
            raise ControllerException(str(exc)) from exc

    def down(self):
-       if self.outpost.managed != MANAGED_OUTPOST:
+       if self.outpost.managed == MANAGED_OUTPOST:
            return
        try:
            container, _ = self._get_container()
+           if container.status == "running":
+               self.logger.info("Stopping container.")
                container.kill()
            self.logger.info("Removing container.")
            container.remove(force=True)
        except DockerException as exc:
            raise ControllerException(str(exc)) from exc
@@ -401,6 +401,7 @@ class Outpost(ManagedModel):
        user = users.first()
        user.attributes[USER_ATTRIBUTE_SA] = True
        user.attributes[USER_ATTRIBUTE_CAN_OVERRIDE_IP] = True
+       user.name = f"Outpost {self.name} Service-Account"
        user.save()
        if should_create_user:
            self.build_user_permissions(user)
@@ -76,7 +76,7 @@ def outpost_service_connection_state(connection_pk: Any):


@CELERY_APP.task(bind=True, base=MonitoredTask)
-@prefill_task()
+@prefill_task
def outpost_service_connection_monitor(self: MonitoredTask):
    """Regularly check the state of Outpost Service Connections"""
    connections = OutpostServiceConnection.objects.all()
@@ -105,9 +105,12 @@ def outpost_controller(
    logs = []
    if from_cache:
        outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
        LOGGER.debug("Getting outpost from cache to delete")
    else:
        outpost: Outpost = Outpost.objects.filter(pk=outpost_pk).first()
        LOGGER.debug("Getting outpost from DB")
+   if not outpost:
+       LOGGER.warning("No outpost")
+       return
    self.set_uid(slugify(outpost.name))
    try:
@@ -126,7 +129,7 @@ def outpost_controller(


@CELERY_APP.task(bind=True, base=MonitoredTask)
-@prefill_task()
+@prefill_task
def outpost_token_ensurer(self: MonitoredTask):
    """Periodically ensure that all Outposts have valid Service Accounts
    and Tokens"""
@@ -90,7 +90,8 @@ class PolicyEngine:
    def build(self) -> "PolicyEngine":
        """Build wrapper which monitors performance"""
        with Hub.current.start_span(
-           op="policy.engine.build"
+           op="policy.engine.build",
+           description=self.__pbm,
        ) as span, HIST_POLICIES_BUILD_TIME.labels(
            object_name=self.__pbm,
            object_type=f"{self.__pbm._meta.app_label}.{self.__pbm._meta.model_name}",
@@ -69,8 +69,8 @@ class Migration(migrations.Migration):
                    ("authentik.stages.user_logout", "authentik Stages.User Logout"),
                    ("authentik.stages.user_write", "authentik Stages.User Write"),
                    ("authentik.tenants", "authentik Tenants"),
-                   ("authentik.core", "authentik Core"),
                    ("authentik.managed", "authentik Managed"),
+                   ("authentik.core", "authentik Core"),
                ],
                default="",
                help_text="Match events created by selected application. When left empty, all applications are matched.",
@@ -11,6 +11,8 @@ from authentik.flows.planner import PLAN_CONTEXT_SSO
 from authentik.lib.expression.evaluator import BaseEvaluator
 from authentik.lib.utils.http import get_client_ip
 from authentik.policies.exceptions import PolicyException
+from authentik.policies.models import Policy, PolicyBinding
+from authentik.policies.process import PolicyProcess
 from authentik.policies.types import PolicyRequest, PolicyResult

 LOGGER = get_logger()
@@ -31,6 +33,7 @@ class PolicyEvaluator(BaseEvaluator):
         self._context["ak_logger"] = get_logger(policy_name)
         self._context["ak_message"] = self.expr_func_message
         self._context["ak_user_has_authenticator"] = self.expr_func_user_has_authenticator
+        self._context["ak_call_policy"] = self.expr_func_call_policy
         self._context["ip_address"] = ip_address
         self._context["ip_network"] = ip_network
         self._filename = policy_name or "PolicyEvaluator"
@@ -39,6 +42,16 @@ class PolicyEvaluator(BaseEvaluator):
         """Wrapper to append to messages list, which is returned with PolicyResult"""
         self._messages.append(message)

+    def expr_func_call_policy(self, name: str, **kwargs) -> PolicyResult:
+        """Call policy by name, with current request"""
+        policy = Policy.objects.filter(name=name).select_subclasses().first()
+        if not policy:
+            raise ValueError(f"Policy '{name}' not found.")
+        req: PolicyRequest = self._context["request"]
+        req.context.update(kwargs)
+        proc = PolicyProcess(PolicyBinding(policy=policy), request=req, connection=None)
+        return proc.profiling_wrapper()
+
     def expr_func_user_has_authenticator(
         self, user: User, device_type: Optional[str] = None
     ) -> bool:
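With `ak_call_policy` registered in the evaluator context, one expression policy can delegate to another and inspect its `PolicyResult`. A sketch of an expression-policy body; the policy name and extra kwarg are placeholders:

```python
# Body of an authentik expression policy (evaluated by PolicyEvaluator).
# "other-policy" is a placeholder; extra kwargs land in request.context.
result = ak_call_policy("other-policy", source="expression")
ak_message(f"delegated policy passing={result.passing}")
return result.passing and not request.user.is_anonymous
```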
@@ -127,8 +127,8 @@ class PolicyProcess(PROCESS_CLASS):
        )
        return policy_result

-   def run(self):  # pragma: no cover
-       """Task wrapper to run policy checking"""
+   def profiling_wrapper(self):
+       """Run with profiling enabled"""
        with Hub.current.start_span(
            op="policy.process.execute",
        ) as span, HIST_POLICIES_EXECUTION_TIME.labels(
@@ -142,8 +142,12 @@ class PolicyProcess(PROCESS_CLASS):
            span: Span
            span.set_data("policy", self.binding.policy)
            span.set_data("request", self.request)
-           try:
-               self.connection.send(self.execute())
-           except Exception as exc:  # pylint: disable=broad-except
-               LOGGER.warning(str(exc))
-               self.connection.send(PolicyResult(False, str(exc)))
+           return self.execute()
+
+   def run(self):  # pragma: no cover
+       """Task wrapper to run policy checking"""
+       try:
+           self.connection.send(self.profiling_wrapper())
+       except Exception as exc:  # pylint: disable=broad-except
+           LOGGER.warning(str(exc))
+           self.connection.send(PolicyResult(False, str(exc)))
@ -16,7 +16,7 @@ LOGGER = get_logger()
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@prefill_task()
|
||||
@prefill_task
|
||||
def save_ip_reputation(self: MonitoredTask):
|
||||
"""Save currently cached reputation to database"""
|
||||
objects_to_update = []
|
||||
@ -30,7 +30,7 @@ def save_ip_reputation(self: MonitoredTask):
|
||||
|
||||
|
||||
@CELERY_APP.task(bind=True, base=MonitoredTask)
|
||||
@prefill_task()
|
||||
@prefill_task
|
||||
def save_user_reputation(self: MonitoredTask):
|
||||
"""Save currently cached reputation to database"""
|
||||
objects_to_update = []
|
||||
|
@ -23,6 +23,6 @@ def invalidate_policy_cache(sender, instance, **_):
|
||||
total += len(keys)
|
||||
cache.delete_many(keys)
|
||||
LOGGER.debug("Invalidating policy cache", policy=instance, keys=total)
|
||||
# Also delete user application cache
|
||||
keys = cache.keys(user_app_cache_key("*")) or []
|
||||
cache.delete_many(keys)
|
||||
# Also delete user application cache
|
||||
keys = cache.keys(user_app_cache_key("*")) or []
|
||||
cache.delete_many(keys)
|
||||
|
@ -10,7 +10,7 @@ from django.views.generic.base import View
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.models import Application, Provider, User
|
||||
from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE
|
||||
from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE, SESSION_KEY_POST
|
||||
from authentik.lib.sentry import SentryIgnoredException
|
||||
from authentik.policies.denied import AccessDeniedResponse
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
@ -84,6 +84,10 @@ class PolicyAccessView(AccessMixin, View):
|
||||
a hint on the Identification Stage what the user should login for."""
|
||||
if self.application:
|
||||
self.request.session[SESSION_KEY_APPLICATION_PRE] = self.application
|
||||
# Because this view might get hit with a POST request, we need to preserve that data
|
||||
# since later views might need it (mostly SAML)
|
||||
if self.request.method.lower() == "post":
|
||||
self.request.session[SESSION_KEY_POST] = self.request.POST
|
||||
return redirect_to_login(
|
||||
self.request.get_full_path(),
|
||||
self.get_login_url(),
|
||||
|
@ -97,7 +97,7 @@ class TokenParams:
|
||||
)
|
||||
# https://tools.ietf.org/html/rfc6749#section-6
|
||||
# Fallback to original token's scopes when none are given
|
||||
if self.scope == []:
|
||||
if not self.scope:
|
||||
self.scope = self.refresh_token.scope
|
||||
except RefreshToken.DoesNotExist:
|
||||
LOGGER.warning(
|
||||
|
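The switch to a truthiness check is subtle but deliberate: the parsed scope may be empty in more than one shape. A minimal sketch of the difference:

# Sketch: `not scope` covers every empty shape, `== []` only one of them
for scope in ([], None, ""):
    print(scope == [], not scope)  # old check vs. new check
# -> (True, True), (False, True), (False, True)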
@ -3,7 +3,7 @@ from typing import Any, Optional

from drf_spectacular.utils import extend_schema_field
from rest_framework.exceptions import ValidationError
from rest_framework.fields import CharField, ListField, SerializerMethodField
from rest_framework.fields import CharField, ListField, ReadOnlyField, SerializerMethodField
from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet

@ -109,6 +109,9 @@ class ProxyProviderViewSet(UsedByMixin, ModelViewSet):
class ProxyOutpostConfigSerializer(ModelSerializer):
"""Proxy provider serializer for outposts"""

assigned_application_slug = ReadOnlyField(source="application.slug")
assigned_application_name = ReadOnlyField(source="application.name")

oidc_configuration = SerializerMethodField()
token_validity = SerializerMethodField()
scopes_to_request = SerializerMethodField()

@ -152,6 +155,8 @@ class ProxyOutpostConfigSerializer(ModelSerializer):
"cookie_domain",
"token_validity",
"scopes_to_request",
"assigned_application_slug",
"assigned_application_name",
]

@ -20,9 +20,11 @@ class TraefikMiddlewareSpecForwardAuth:

address: str
# pylint: disable=invalid-name
authResponseHeaders: list[str]
authResponseHeadersRegex: str = field(default="")
# pylint: disable=invalid-name
trustForwardHeader: bool
authResponseHeaders: list[str] = field(default_factory=list)
# pylint: disable=invalid-name
trustForwardHeader: bool = field(default=True)

@dataclass

@ -108,21 +110,8 @@ class TraefikMiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware])
spec=TraefikMiddlewareSpec(
forwardAuth=TraefikMiddlewareSpecForwardAuth(
address=f"http://{self.name}.{self.namespace}:9000/akprox/auth/traefik",
authResponseHeaders=[
"Set-Cookie",
# Legacy headers, remove after 2022.1
"X-Auth-Username",
"X-Auth-Groups",
"X-Forwarded-Email",
"X-Forwarded-Preferred-Username",
"X-Forwarded-User",
# New headers, unique prefix
"X-authentik-username",
"X-authentik-groups",
"X-authentik-email",
"X-authentik-name",
"X-authentik-uid",
],
authResponseHeaders=[],
authResponseHeadersRegex="^.*$",
trustForwardHeader=True,
)
),
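With the new defaults, the reconciler emits a forward-auth spec that forwards every outpost response header via a regex rather than a fixed allow-list. A minimal sketch of the resulting object (service name and namespace hypothetical):

spec = TraefikMiddlewareSpecForwardAuth(
    address="http://ak-outpost-example.authentik:9000/akprox/auth/traefik",  # hypothetical name/namespace
    authResponseHeaders=[],
    authResponseHeadersRegex="^.*$",  # forward all headers set by the outpost
    trustForwardHeader=True,
)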
@ -36,6 +36,7 @@ from authentik.flows.models import Flow, FlowDesignation
from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.metadata import MetadataProcessor
from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser
from authentik.sources.saml.processors.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT

LOGGER = get_logger()

@ -109,7 +110,17 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet):
name="download",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.BOOL,
)
),
OpenApiParameter(
name="force_binding",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
enum=[
SAML_BINDING_REDIRECT,
SAML_BINDING_POST,
],
description=("Optionally force the metadata to only include one binding."),
),
],
)
@action(methods=["GET"], detail=True, permission_classes=[AllowAny])

@ -122,8 +133,10 @@ class SAMLProviderViewSet(UsedByMixin, ModelViewSet):
except ValueError:
raise Http404
try:
metadata = MetadataProcessor(provider, request).build_entity_descriptor()
if "download" in request._request.GET:
proc = MetadataProcessor(provider, request)
proc.force_binding = request.query_params.get("force_binding", None)
metadata = proc.build_entity_descriptor()
if "download" in request.query_params:
response = HttpResponse(metadata, content_type="application/xml")
response[
"Content-Disposition"

@ -101,7 +101,8 @@ class AssertionProcessor:

attribute_statement.append(attribute)

except PropertyMappingExpressionException as exc:
except (PropertyMappingExpressionException, ValueError) as exc:
# Value error can be raised when assigning invalid data to an attribute
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Failed to evaluate property-mapping: {str(exc)}",

@ -29,10 +29,12 @@ class MetadataProcessor:

provider: SAMLProvider
http_request: HttpRequest
force_binding: Optional[str]

def __init__(self, provider: SAMLProvider, request: HttpRequest):
self.provider = provider
self.http_request = request
self.force_binding = None
self.xml_id = get_random_id()

def get_signing_key_descriptor(self) -> Optional[Element]:

@ -79,6 +81,8 @@ class MetadataProcessor:
),
}
for binding, url in binding_url_map.items():
if self.force_binding and self.force_binding != binding:
continue
element = Element(f"{{{NS_SAML_METADATA}}}SingleSignOnService")
element.attrib["Binding"] = binding
element.attrib["Location"] = url
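Once deployed, the new parameter can be exercised over the API. A minimal sketch with the `requests` library (host, provider ID, and exact endpoint path are assumptions; the binding value is the standard SAML 2.0 POST binding URN):

import requests

resp = requests.get(
    "https://authentik.example.com/api/v3/providers/saml/1/metadata/",  # hypothetical host/ID
    params={
        "download": "true",
        "force_binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
    },
)
print(resp.headers.get("Content-Disposition"))  # attachment header when download is set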
@ -100,14 +100,13 @@ class AuthNRequestParser:
xmlsec.tree.add_ids(root, ["ID"])
signature_nodes = root.xpath("/samlp:AuthnRequest/ds:Signature", namespaces=NS_MAP)
# No signatures, no verifier configured -> decode xml directly
if len(signature_nodes) < 1 and not verifier:
return self._parse_xml(decoded_xml, relay_state)
if len(signature_nodes) < 1:
if not verifier:
return self._parse_xml(decoded_xml, relay_state)
raise CannotHandleAssertion(ERROR_SIGNATURE_REQUIRED_BUT_ABSENT)

signature_node = signature_nodes[0]

if verifier and signature_node is None:
raise CannotHandleAssertion(ERROR_SIGNATURE_REQUIRED_BUT_ABSENT)

if signature_node is not None:
if not verifier:
raise CannotHandleAssertion(ERROR_SIGNATURE_EXISTS_BUT_NO_VERIFIER)

@ -13,7 +13,7 @@ from authentik.core.models import Application
from authentik.events.models import Event, EventAction
from authentik.flows.models import in_memory_stage
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.flows.views.executor import SESSION_KEY_PLAN, SESSION_KEY_POST
from authentik.lib.utils.urls import redirect_with_qs
from authentik.lib.views import bad_request_message
from authentik.policies.views import PolicyAccessView

@ -37,7 +37,7 @@ LOGGER = get_logger()

class SAMLSSOView(PolicyAccessView):
""" "SAML SSO Base View, which plans a flow and injects our final stage.
"""SAML SSO Base View, which plans a flow and injects our final stage.
Calls get/post handler."""

def resolve_provider_application(self):

@ -120,14 +120,20 @@ class SAMLSSOBindingPOSTView(SAMLSSOView):

def check_saml_request(self) -> Optional[HttpRequest]:
"""Handle POST bindings"""
if REQUEST_KEY_SAML_REQUEST not in self.request.POST:
payload = self.request.POST
# Restore the post body from the session
# This happens when using POST bindings but the user isn't logged in
# (user gets redirected and POST body is 'lost')
if SESSION_KEY_POST in self.request.session:
payload = self.request.session.pop(SESSION_KEY_POST)
if REQUEST_KEY_SAML_REQUEST not in payload:
LOGGER.info("check_saml_request: SAML payload missing")
return bad_request_message(self.request, "The SAML request payload is missing.")

try:
auth_n_request = AuthNRequestParser(self.provider).parse(
self.request.POST[REQUEST_KEY_SAML_REQUEST],
self.request.POST.get(REQUEST_KEY_RELAY_STATE),
payload[REQUEST_KEY_SAML_REQUEST],
payload.get(REQUEST_KEY_RELAY_STATE),
)
self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request
except CannotHandleAssertion as exc:

@ -14,6 +14,7 @@ from celery.signals import (
from django.conf import settings
from structlog.stdlib import get_logger

from authentik.core.middleware import LOCAL
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string

@ -26,7 +27,7 @@ CELERY_APP = Celery("authentik")

# pylint: disable=unused-argument
@setup_logging.connect
def config_loggers(*args, **kwags):
def config_loggers(*args, **kwargs):
"""Apply logging settings from settings.py to celery"""
dictConfig(settings.LOGGING)

@ -36,21 +37,29 @@ def config_loggers(*args, **kwags):
def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs):
"""Log task_id after it was published"""
info = headers if "task" in headers else body
LOGGER.debug("Task published", task_id=info.get("id", ""), task_name=info.get("task", ""))
LOGGER.info("Task published", task_id=info.get("id", ""), task_name=info.get("task", ""))

# pylint: disable=unused-argument
@task_prerun.connect
def task_prerun_hook(task_id, task, *args, **kwargs):
def task_prerun_hook(task_id: str, task, *args, **kwargs):
"""Log task_id on worker"""
LOGGER.debug("Task started", task_id=task_id, task_name=task.__name__)
request_id = "task-" + task_id.replace("-", "")
LOCAL.authentik_task = {
"request_id": request_id,
}
LOGGER.info("Task started", task_id=task_id, task_name=task.__name__)

# pylint: disable=unused-argument
@task_postrun.connect
def task_postrun_hook(task_id, task, *args, retval=None, state=None, **kwargs):
"""Log task_id on worker"""
LOGGER.debug("Task finished", task_id=task_id, task_name=task.__name__, state=state)
LOGGER.info("Task finished", task_id=task_id, task_name=task.__name__, state=state)
if not hasattr(LOCAL, "authentik_task"):
return
for key in list(LOCAL.authentik_task.keys()):
del LOCAL.authentik_task[key]

# pylint: disable=unused-argument
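The prerun hook derives a per-task request ID so worker log lines can be correlated with the task run; the transformation is just the task ID with dashes stripped:

# Sketch with a made-up task id
task_id = "5f4d9c1e-0f2a-4c3b-9a87-1d2e3f405a6b"
request_id = "task-" + task_id.replace("-", "")
# -> "task-5f4d9c1e0f2a4c3b9a871d2e3f405a6b", stored on LOCAL.authentik_task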
@ -24,6 +24,7 @@ import structlog
from celery.schedules import crontab
from sentry_sdk import init as sentry_init
from sentry_sdk.api import set_tag
from sentry_sdk.integrations.boto3 import Boto3Integration
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.django import DjangoIntegration
from sentry_sdk.integrations.redis import RedisIntegration

@ -231,6 +232,7 @@ CACHES = {
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
}
}
DJANGO_REDIS_SCAN_ITERSIZE = 1000
DJANGO_REDIS_IGNORE_EXCEPTIONS = True
DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = True
SESSION_ENGINE = "django.contrib.sessions.backends.cache"

@ -421,10 +423,11 @@ if _ERROR_REPORTING:
DjangoIntegration(transaction_style="function_name"),
CeleryIntegration(),
RedisIntegration(),
Boto3Integration(),
],
before_send=before_send,
release=f"authentik@{__version__}",
traces_sample_rate=float(CONFIG.y("error_reporting.sample_rate", 0.4)),
traces_sample_rate=float(CONFIG.y("error_reporting.sample_rate", 0.5)),
environment=CONFIG.y("error_reporting.environment", "customer"),
send_default_pii=CONFIG.y_bool("error_reporting.send_pii", False),
)

@ -1,4 +1,6 @@
"""Integrate ./manage.py test with pytest"""
from argparse import ArgumentParser

from django.conf import settings

from authentik.lib.config import CONFIG

@ -8,34 +10,43 @@ from tests.e2e.utils import get_docker_tag
class PytestTestRunner: # pragma: no cover
"""Runs pytest to discover and run tests."""

def __init__(self, verbosity=1, failfast=False, keepdb=False, **_):
def __init__(self, verbosity=1, failfast=False, keepdb=False, **kwargs):
self.verbosity = verbosity
self.failfast = failfast
self.keepdb = keepdb

self.args = ["-vv"]
if self.failfast:
self.args.append("--exitfirst")
if self.keepdb:
self.args.append("--reuse-db")

if kwargs.get("randomly_seed", None):
self.args.append(f"--randomly-seed={kwargs['randomly_seed']}")

settings.TEST = True
settings.CELERY_TASK_ALWAYS_EAGER = True
CONFIG.y_set("authentik.avatars", "none")
CONFIG.y_set("authentik.geoip", "tests/GeoLite2-City-Test.mmdb")
CONFIG.y_set(
"outposts.container_image_base",
f"goauthentik.io/dev-%(type)s:{get_docker_tag()}",
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
)

@classmethod
def add_arguments(cls, parser: ArgumentParser):
"""Add more pytest-specific arguments"""
parser.add_argument("--randomly-seed", type=int)

def run_tests(self, test_labels):
"""Run pytest and return the exitcode.

It translates some of Django's test command option to pytest's.
"""

import pytest

argv = ["-vv"]
if self.failfast:
argv.append("--exitfirst")
if self.keepdb:
argv.append("--reuse-db")

if any("tests/e2e" in label for label in test_labels):
argv.append("-pno:randomly")

argv.extend(test_labels)
return pytest.main(argv)
self.args.append("-pno:randomly")
self.args.extend(test_labels)
return pytest.main(self.args)
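A minimal sketch of how the rewritten runner maps Django's test options onto pytest arguments (the test label is hypothetical):

runner = PytestTestRunner(failfast=True, keepdb=True, randomly_seed=1234)
# runner.args == ["-vv", "--exitfirst", "--reuse-db", "--randomly-seed=1234"]
exit_code = runner.run_tests(["authentik/policies"])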
@ -43,6 +43,7 @@ class LDAPSourceSerializer(SourceSerializer):
model = LDAPSource
fields = SourceSerializer.Meta.fields + [
"server_uri",
"peer_certificate",
"bind_cn",
"bind_password",
"start_tls",

@ -73,11 +74,9 @@ class LDAPSourceViewSet(UsedByMixin, ModelViewSet):
"name",
"slug",
"enabled",
"authentication_flow",
"enrollment_flow",
"policy_engine_mode",
"server_uri",
"bind_cn",
"peer_certificate",
"start_tls",
"base_dn",
"additional_user_dn",

@ -58,7 +58,7 @@ class LDAPBackend(InbuiltBackend):
LOGGER.debug("Attempting Binding as user", user=user)
try:
temp_connection = ldap3.Connection(
source.connection.server,
source.server,
user=user.attributes.get(LDAP_DISTINGUISHED_NAME),
password=password,
raise_exceptions=True,

38 authentik/sources/ldap/migrations/0002_auto_20211203_0900.py (new file)
@ -0,0 +1,38 @@
# Generated by Django 3.2.9 on 2021-12-03 09:00

import django.db.models.deletion
from django.db import migrations, models

import authentik.sources.ldap.models

class Migration(migrations.Migration):

dependencies = [
("authentik_crypto", "0003_certificatekeypair_managed"),
("authentik_sources_ldap", "0001_squashed_0012_auto_20210812_1703"),
]

operations = [
migrations.AddField(
model_name="ldapsource",
name="peer_certificate",
field=models.ForeignKey(
default=None,
help_text="Optionally verify the LDAP Server's Certificate against the CA Chain in this keypair.",
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_crypto.certificatekeypair",
),
),
migrations.AlterField(
model_name="ldapsource",
name="server_uri",
field=models.TextField(
validators=[
authentik.sources.ldap.models.MultiURLValidator(schemes=["ldap", "ldaps"])
],
verbose_name="Server URI",
),
),
]

@ -1,24 +1,48 @@
"""authentik LDAP Models"""
from typing import Optional, Type
from ssl import CERT_REQUIRED
from typing import Type

from django.db import models
from django.utils.translation import gettext_lazy as _
from ldap3 import ALL, Connection, Server
from ldap3 import ALL, RANDOM, Connection, Server, ServerPool, Tls
from rest_framework.serializers import Serializer

from authentik.core.models import Group, PropertyMapping, Source
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.models import DomainlessURLValidator

LDAP_TIMEOUT = 15

class MultiURLValidator(DomainlessURLValidator):
"""Same as DomainlessURLValidator but supports multiple URLs separated with a comma."""

def __call__(self, value: str):
if "," in value:
for url in value.split(","):
super().__call__(url)
else:
super().__call__(value)

class LDAPSource(Source):
"""Federate LDAP Directory with authentik, or create new accounts in LDAP."""

server_uri = models.TextField(
validators=[DomainlessURLValidator(schemes=["ldap", "ldaps"])],
validators=[MultiURLValidator(schemes=["ldap", "ldaps"])],
verbose_name=_("Server URI"),
)
peer_certificate = models.ForeignKey(
CertificateKeyPair,
on_delete=models.SET_DEFAULT,
default=None,
null=True,
help_text=_(
"Optionally verify the LDAP Server's Certificate "
"against the CA Chain in this keypair."
),
)

bind_cn = models.TextField(verbose_name=_("Bind CN"), blank=True)
bind_password = models.TextField(blank=True)
start_tls = models.BooleanField(default=False, verbose_name=_("Enable Start TLS"))

@ -82,25 +106,40 @@ class LDAPSource(Source):

return LDAPSourceSerializer

_connection: Optional[Connection] = None
@property
def server(self) -> Server:
"""Get LDAP Server/ServerPool"""
servers = []
tls = Tls()
if self.peer_certificate:
tls = Tls(ca_certs_data=self.peer_certificate.certificate_data, validate=CERT_REQUIRED)
kwargs = {
"get_info": ALL,
"connect_timeout": LDAP_TIMEOUT,
"tls": tls,
}
if "," in self.server_uri:
for server in self.server_uri.split(","):
servers.append(Server(server, **kwargs))
else:
servers = [Server(self.server_uri, **kwargs)]
return ServerPool(servers, RANDOM, active=True, exhaust=True)

@property
def connection(self) -> Connection:
"""Get a fully connected and bound LDAP Connection"""
if not self._connection:
server = Server(self.server_uri, get_info=ALL, connect_timeout=LDAP_TIMEOUT)
self._connection = Connection(
server,
raise_exceptions=True,
user=self.bind_cn,
password=self.bind_password,
receive_timeout=LDAP_TIMEOUT,
)
connection = Connection(
self.server,
raise_exceptions=True,
user=self.bind_cn,
password=self.bind_password,
receive_timeout=LDAP_TIMEOUT,
)

self._connection.bind()
if self.start_tls:
self._connection.start_tls()
return self._connection
connection.bind()
if self.start_tls:
connection.start_tls()
return connection

class Meta:
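A minimal sketch of the new failover behaviour from the caller's side (URIs and credentials hypothetical):

source = LDAPSource(
    name="corp-ldap",  # hypothetical
    server_uri="ldaps://ldap1.example.com,ldaps://ldap2.example.com",
    bind_cn="cn=service,dc=example,dc=com",
    bind_password="...",
)
# source.server builds a RANDOM-strategy ServerPool over both URIs;
# source.connection then binds against whichever server the pool picks.
conn = source.connection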
@ -51,7 +51,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer):
},
defaults,
)
except (IntegrityError, FieldError) as exc:
except (IntegrityError, FieldError, TypeError) as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message=(

@ -45,7 +45,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer):
ak_user, created = self.update_or_create_attributes(
User, {f"attributes__{LDAP_UNIQUENESS}": uniq}, defaults
)
except (IntegrityError, FieldError) as exc:
except (IntegrityError, FieldError, TypeError) as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message=(

@ -39,7 +39,7 @@ def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str):
# to set the state with
return
sync = path_to_class(sync_class)
self.set_uid(f"{slugify(source.name)}-{sync.__name__}")
self.set_uid(f"{slugify(source.name)}_{sync.__name__.replace('LDAPSynchronizer', '').lower()}")
try:
sync_inst = sync(source)
count = sync_inst.sync()
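The shortened UID drops the class-name suffix, so for a source named e.g. "Corp AD" (hypothetical) the per-sync task UIDs come out as:

from django.utils.text import slugify

cls_name = "UserLDAPSynchronizer"
uid = f"{slugify('Corp AD')}_{cls_name.replace('LDAPSynchronizer', '').lower()}"
assert uid == "corp-ad_user"  # likewise "corp-ad_group", "corp-ad_membership"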
@ -120,9 +120,9 @@ class LDAPSyncTests(TestCase):
self.source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name")
)
self.source.save()
connection = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync()
membership_sync = MembershipLDAPSynchronizer(self.source)

@ -143,9 +143,9 @@ class LDAPSyncTests(TestCase):
self.source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/openldap-cn")
)
self.source.save()
connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
group_sync = GroupLDAPSynchronizer(self.source)
group_sync.sync()
membership_sync = MembershipLDAPSynchronizer(self.source)

@ -168,9 +168,9 @@ class LDAPSyncTests(TestCase):
self.source.property_mappings_group.set(
LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/openldap-cn")
)
self.source.save()
connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD))
with patch("authentik.sources.ldap.models.LDAPSource.connection", connection):
self.source.save()
user_sync = UserLDAPSynchronizer(self.source)
user_sync.sync()
group_sync = GroupLDAPSynchronizer(self.source)

@ -1,10 +1,9 @@
"""OAuth Source Serializer"""
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.viewsets import GenericViewSet
from rest_framework.viewsets import ModelViewSet

from authentik.api.authorization import OwnerFilter, OwnerPermissions
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.oauth.models import UserOAuthSourceConnection

@ -15,30 +14,19 @@ class UserOAuthSourceConnectionSerializer(SourceSerializer):

class Meta:
model = UserOAuthSourceConnection
fields = [
"pk",
"user",
"source",
"identifier",
]
fields = ["pk", "user", "source", "identifier", "access_token"]
extra_kwargs = {
"user": {"read_only": True},
"access_token": {"write_only": True},
}

class UserOAuthSourceConnectionViewSet(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class UserOAuthSourceConnectionViewSet(UsedByMixin, ModelViewSet):
"""Source Viewset"""

queryset = UserOAuthSourceConnection.objects.all()
serializer_class = UserOAuthSourceConnectionSerializer
filterset_fields = ["source__slug"]
permission_classes = [OwnerPermissions]
permission_classes = [OwnerSuperuserPermissions]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ordering = ["source__slug"]

@ -1,10 +1,9 @@
"""Plex Source connection Serializer"""
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import mixins
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.viewsets import GenericViewSet
from rest_framework.viewsets import ModelViewSet

from authentik.api.authorization import OwnerFilter, OwnerPermissions
from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.sources.plex.models import PlexSourceConnection

@ -27,19 +26,12 @@ class PlexSourceConnectionSerializer(SourceSerializer):
}

class PlexSourceConnectionViewSet(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
UsedByMixin,
mixins.ListModelMixin,
GenericViewSet,
):
class PlexSourceConnectionViewSet(UsedByMixin, ModelViewSet):
"""Plex Source connection Serializer"""

queryset = PlexSourceConnection.objects.all()
serializer_class = PlexSourceConnectionSerializer
filterset_fields = ["source__slug"]
permission_classes = [OwnerPermissions]
permission_classes = [OwnerSuperuserPermissions]
filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter]
ordering = ["pk"]

@ -29,14 +29,15 @@ def check_plex_token(self: MonitoredTask, source_slug: int):
auth.get_user_info()
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."]))
except RequestException as exc:
error = exception_to_string(exc).replace(source.plex_token, "$PLEX_TOKEN")
self.set_status(
TaskResult(
TaskResultStatus.ERROR,
["Plex token is invalid/an error occurred:", exception_to_string(exc)],
["Plex token is invalid/an error occurred:", error],
)
)
Event.new(
EventAction.CONFIGURATION_ERROR,
message=f"Plex token invalid, please re-authenticate source.\n{str(exc)}",
message=f"Plex token invalid, please re-authenticate source.\n{error}",
source=source,
).save()

@ -17,7 +17,7 @@ LOGGER = get_logger()

@CELERY_APP.task(bind=True, base=MonitoredTask)
@prefill_task()
@prefill_task
def clean_temporary_users(self: MonitoredTask):
"""Remove temporary users created by SAML Sources"""
_now = now()

@ -53,9 +53,6 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):

def validate_response(self, response: dict) -> dict:
"""Validate webauthn challenge response"""
# pylint: disable=no-name-in-module
from pydantic.error_wrappers import ValidationError as PydanticValidationError

challenge = self.request.session["challenge"]

try:

@ -65,7 +62,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
expected_rp_id=get_rp_id(self.request),
expected_origin=get_origin(self.request),
)
except (InvalidRegistrationResponse, PydanticValidationError) as exc:
except InvalidRegistrationResponse as exc:
LOGGER.warning("registration failed", exc=exc)
raise ValidationError(f"Registration failed. Error: {exc}")

@ -12,17 +12,16 @@ from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger

from authentik.core.models import Token
from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.models import FlowToken
from authentik.flows.planner import PLAN_CONTEXT_IS_RESTORED, PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import ChallengeStageView
from authentik.flows.views.executor import SESSION_KEY_GET
from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_GET
from authentik.stages.email.models import EmailStage
from authentik.stages.email.tasks import send_mails
from authentik.stages.email.utils import TemplateEmailMessage

LOGGER = get_logger()
QS_KEY_TOKEN = "etoken" # nosec
PLAN_CONTEXT_EMAIL_SENT = "email_sent"

@ -56,7 +55,7 @@ class EmailStageView(ChallengeStageView):
relative_url = f"{base_url}?{urlencode(kwargs)}"
return self.request.build_absolute_uri(relative_url)

def get_token(self) -> Token:
def get_token(self) -> FlowToken:
"""Get token"""
pending_user = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
current_stage: EmailStage = self.executor.current_stage

@ -65,10 +64,14 @@ class EmailStageView(ChallengeStageView):
) # + 1 because django timesince always rounds down
identifier = slugify(f"ak-email-stage-{current_stage.name}-{pending_user}")
# Don't check for validity here, we only care if the token exists
tokens = Token.objects.filter(identifier=identifier)
tokens = FlowToken.objects.filter(identifier=identifier)
if not tokens.exists():
return Token.objects.create(
expires=now() + valid_delta, user=pending_user, identifier=identifier
return FlowToken.objects.create(
expires=now() + valid_delta,
user=pending_user,
identifier=identifier,
flow=self.executor.flow,
_plan=FlowToken.pickle(self.executor.plan),
)
token = tokens.first()
# Check if token is expired and rotate key if so

@ -97,13 +100,9 @@ class EmailStageView(ChallengeStageView):

def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
# Check if the user came back from the email link to verify
if QS_KEY_TOKEN in request.session.get(SESSION_KEY_GET, {}):
tokens = Token.filter_not_expired(key=request.session[SESSION_KEY_GET][QS_KEY_TOKEN])
if not tokens.exists():
return self.executor.stage_invalid(_("Invalid token"))
token = tokens.first()
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = token.user
token.delete()
if QS_KEY_TOKEN in request.session.get(
SESSION_KEY_GET, {}
) and self.executor.plan.context.get(PLAN_CONTEXT_IS_RESTORED, False):
messages.success(request, _("Successfully verified Email."))
if self.executor.current_stage.activate_user_on_success:
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER].is_active = True
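For reference, the verification link embedded in the email carries the token under the relocated `etoken` query key. A minimal sketch of the URL construction (host and flow slug hypothetical; `token` is the FlowToken created above):

from urllib.parse import urlencode

base_url = "https://authentik.example.com/if/flow/default-recovery-flow/"
link = f"{base_url}?{urlencode({'etoken': token.key})}"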
43 authentik/stages/prompt/migrations/0006_alter_prompt_type.py (new file)
@ -0,0 +1,43 @@
# Generated by Django 3.2.9 on 2021-12-03 09:00

from django.db import migrations, models

class Migration(migrations.Migration):

dependencies = [
("authentik_stages_prompt", "0005_alter_prompt_field_key"),
]

operations = [
migrations.AlterField(
model_name="prompt",
name="type",
field=models.CharField(
choices=[
("text", "Text: Simple Text input"),
(
"text_read_only",
"Text (read-only): Simple Text input, but cannot be edited.",
),
(
"username",
"Username: Same as Text input, but checks for and prevents duplicate usernames.",
),
("email", "Email: Text field with Email type."),
(
"password",
"Password: Masked input, password is validated against sources. Policies still have to be applied to this Stage. If two of these are used in the same stage, they are ensured to be identical.",
),
("number", "Number"),
("checkbox", "Checkbox"),
("date", "Date"),
("date-time", "Date Time"),
("separator", "Separator: Static Separator Line"),
("hidden", "Hidden: Hidden field, can be used to insert data into form."),
("static", "Static: Static value, displayed as-is."),
],
max_length=100,
),
),
]

@ -113,6 +113,9 @@ class Prompt(SerializerModel):
kwargs["label"] = ""
if default:
kwargs["default"] = default
# May not set both `required` and `default`
if "default" in kwargs:
kwargs.pop("required", None)
return field_class(**kwargs)

def save(self, *args, **kwargs):
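The pop is needed because Django REST Framework refuses fields configured with both options; a minimal sketch:

from rest_framework.fields import CharField

CharField(default="fallback")  # fine: a default implies the field is optional
try:
    CharField(required=True, default="fallback")
except AssertionError as exc:
    print(exc)  # "May not set both `required` and `default`"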
@ -18,7 +18,7 @@ from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTyp
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.stage import ChallengeStageView
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyBindingModel
from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
from authentik.stages.prompt.signals import password_validate

@ -110,6 +110,7 @@ class PromptChallengeResponse(ChallengeResponse):

user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user())
engine = ListPolicyEngine(self.stage.validation_policies.all(), user, self.request)
engine.mode = PolicyEngineMode.MODE_ALL
engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
engine.request.context.update(attrs)
engine.build()

@ -17,7 +17,7 @@ services:
image: redis:alpine
restart: unless-stopped
server:
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.12.1-rc1}
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.12.1-rc4}
restart: unless-stopped
command: server
environment:

@ -38,7 +38,7 @@ services:
- "0.0.0.0:9000:9000"
- "0.0.0.0:9443:9443"
worker:
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.12.1-rc1}
image: ${AUTHENTIK_IMAGE:-goauthentik.io/server}:${AUTHENTIK_TAG:-2021.12.1-rc4}
restart: unless-stopped
command: worker
environment:

@ -55,6 +55,7 @@ services:
volumes:
- ./backups:/backups
- ./media:/media
- ./certs:/certs
- /var/run/docker.sock:/var/run/docker.sock
- ./custom-templates:/templates
- geoip:/geoip

4 go.mod
@ -27,12 +27,12 @@ require (
github.com/pkg/errors v0.9.1
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
github.com/prometheus/client_golang v1.11.0
github.com/recws-org/recws v1.3.1
github.com/sirupsen/logrus v1.8.1
goauthentik.io/api v0.2021104.6
goauthentik.io/api v0.2021104.11
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect
golang.org/x/net v0.0.0-20210510120150-4163338589ed // indirect
golang.org/x/oauth2 v0.0.0-20210323180902-22b0adad7558
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
google.golang.org/appengine v1.6.7 // indirect
gopkg.in/boj/redistore.v1 v1.0.0-20160128113310-fc113767cd6b
gopkg.in/square/go-jose.v2 v2.5.1 // indirect

9 go.sum
@ -356,7 +356,6 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=

@ -481,8 +480,6 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/recws-org/recws v1.3.1 h1:vtRhYpgNPBs3iFyu/+zxBqNzLYgID7UPC5siThkvbs0=
github.com/recws-org/recws v1.3.1/go.mod h1:gRH/uJLMsO7lbcecAB1Im1Zc6eKxs93ftGR0R39QeYA=
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=

@ -561,8 +558,8 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
goauthentik.io/api v0.2021104.6 h1:1Vyw1gnVm9D7htUXWTcy7Gg7ldU0V0vIhT8RFo9G/Iw=
goauthentik.io/api v0.2021104.6/go.mod h1:02nnD4FRd8lu8A1+ZuzqownBgvAhdCKzqkKX8v7JMTE=
goauthentik.io/api v0.2021104.11 h1:LqT0LM0e/RRrxPuo6Xl5uz3PCR5ytuE+YlNlfW9w0yU=
goauthentik.io/api v0.2021104.11/go.mod h1:02nnD4FRd8lu8A1+ZuzqownBgvAhdCKzqkKX8v7JMTE=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

@ -672,6 +669,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

@ -16,9 +16,8 @@ func DefaultConfig() {
G = Config{
Debug: false,
Web: WebConfig{
Listen: "localhost:9000",
ListenTLS: "localhost:9443",
LoadLocalFiles: false,
Listen: "localhost:9000",
ListenTLS: "localhost:9443",
},
Paths: PathsConfig{
Media: "./media",

@ -30,7 +30,6 @@ type WebConfig struct {
Listen string `yaml:"listen"`
ListenTLS string `yaml:"listen_tls"`
ListenMetrics string `yaml:"listen_metrics"`
LoadLocalFiles bool `yaml:"load_local_files" env:"AUTHENTIK_WEB_LOAD_LOCAL_FILES"`
DisableEmbeddedOutpost bool `yaml:"disable_embedded_outpost" env:"AUTHENTIK_WEB__DISABLE_EMBEDDED_OUTPOST"`
}

@ -17,4 +17,4 @@ func OutpostUserAgent() string {
return fmt.Sprintf("authentik-outpost@%s (build=%s)", VERSION, BUILD())
}

const VERSION = "2021.12.1-rc1"
const VERSION = "2021.12.1-rc4"

@ -11,10 +11,10 @@ import (
"syscall"
"time"

"github.com/go-openapi/strfmt"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/recws-org/recws"
"goauthentik.io/api"
"goauthentik.io/internal/constants"

@ -35,20 +35,25 @@ type APIController struct {

logger *log.Entry

reloadOffset time.Duration
lastWsReconnect time.Time
reloadOffset time.Duration

wsConn *websocket.Conn
lastWsReconnect time.Time
wsIsReconnecting bool
wsBackoffMultiplier int

wsConn *recws.RecConn
instanceUUID uuid.UUID
}

// NewAPIController initialise new API Controller instance from URL and API token
func NewAPIController(akURL url.URL, token string) *APIController {
rsp := sentry.StartSpan(context.TODO(), "authentik.outposts.init")

config := api.NewConfiguration()
config.Host = akURL.Host
config.Scheme = akURL.Scheme
config.HTTPClient = &http.Client{
Transport: NewUserAgentTransport(constants.OutpostUserAgent(), NewTracingTransport(context.TODO(), GetTLSTransport())),
Transport: NewUserAgentTransport(constants.OutpostUserAgent(), NewTracingTransport(rsp.Context(), GetTLSTransport())),
}
config.AddDefaultHeader("Authorization", fmt.Sprintf("Bearer %s", token))

@ -85,12 +90,16 @@ func NewAPIController(akURL url.URL, token string) *APIController {
token: token,
logger: log,

reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
wsBackoffMultiplier: 1,
}
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
ac.initWS(akURL, strfmt.UUID(outpost.Pk))
err = ac.initWS(akURL, outpost.Pk)
if err != nil {
go ac.reconnectWS()
}
ac.configureRefreshSignal()
return ac
}

@ -148,10 +157,6 @@ func (a *APIController) StartBackgorundTasks() error {
"version": constants.VERSION,
"build": constants.BUILD(),
}).Set(1)
go func() {
a.logger.Debug("Starting WS re-connector...")
a.startWSReConnector()
}()
go func() {
a.logger.Debug("Starting WS Handler...")
a.startWSHandler()

@ -6,18 +6,17 @@ import (
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"

"github.com/go-openapi/strfmt"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/recws-org/recws"
"goauthentik.io/internal/constants"
)

func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
pathTemplate := "%s://%s/ws/outpost/%s/"
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
pathTemplate := "%s://%s/ws/outpost/%s/?%s"
scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")

authHeader := fmt.Sprintf("Bearer %s", ac.token)

@ -32,15 +31,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
value = "false"
}

ws := &recws.RecConn{
NonVerbose: true,
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: strings.ToLower(value) == "true",
},
}
ws.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID.String()), header)

ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("Connecting to authentik")
ws, _, err := dialer.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID, akURL.Query().Encode()), header)
if err != nil {
ac.logger.WithError(err).Warning("failed to connect websocket")
return err
}

ac.wsConn = ws
// Send hello message with our version

@ -52,11 +55,14 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
"uuid": ac.instanceUUID.String(),
},
}
err := ws.WriteJSON(msg)
err = ws.WriteJSON(msg)
if err != nil {
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik")
return err
}
ac.lastWsReconnect = time.Now()
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket")
return nil
}

// Shutdown Gracefully stops all workers, disconnects from websocket

@ -65,21 +71,43 @@ func (ac *APIController) Shutdown() {
// waiting (with timeout) for the server to close the connection.
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
ac.logger.Println("write close:", err)
ac.logger.WithError(err).Warning("failed to write close message")
return
}
err = ac.wsConn.Close()
if err != nil {
ac.logger.WithError(err).Warning("failed to close websocket")
}
ac.logger.Info("finished shutdown")
}

func (ac *APIController) startWSReConnector() {
func (ac *APIController) reconnectWS() {
if ac.wsIsReconnecting {
return
}
ac.wsIsReconnecting = true
u := url.URL{
Host: ac.Client.GetConfig().Host,
Scheme: ac.Client.GetConfig().Scheme,
}
attempt := 1
for {
time.Sleep(time.Second * 5)
if ac.wsConn.IsConnected() {
continue
}
if time.Since(ac.lastWsReconnect).Seconds() > 30 {
ac.wsConn.CloseAndReconnect()
ac.logger.Info("Reconnecting websocket")
ac.lastWsReconnect = time.Now()
q := u.Query()
q.Set("attempt", strconv.Itoa(attempt))
u.RawQuery = q.Encode()
err := ac.initWS(u, ac.Outpost.Pk)
attempt += 1
if err != nil {
ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier)
time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second)
ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2
// Limit to 300 seconds (5m)
if ac.wsBackoffMultiplier >= 300 {
ac.wsBackoffMultiplier = 300
}
} else {
ac.wsIsReconnecting = false
return
}
}
}
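The reconnect loop doubles its wait after each failed attempt and caps at five minutes. A minimal Python sketch of the same schedule:

from itertools import islice

def backoff_schedule(cap: int = 300):
    """Yield the doubling reconnect delays, capped at `cap` seconds."""
    delay = 1
    while True:
        yield delay
        delay = min(delay * 2, cap)

print(list(islice(backoff_schedule(), 10)))
# -> [1, 2, 4, 8, 16, 32, 64, 128, 256, 300]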
@ -88,6 +116,11 @@ func (ac *APIController) startWSHandler() {
|
||||
logger := ac.logger.WithField("loop", "ws-handler")
|
||||
for {
|
||||
var wsMsg websocketMessage
|
||||
if ac.wsConn == nil {
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
}
|
||||
err := ac.wsConn.ReadJSON(&wsMsg)
|
||||
if err != nil {
|
||||
ConnectionStatus.With(prometheus.Labels{
|
||||
@ -96,6 +129,7 @@ func (ac *APIController) startWSHandler() {
|
||||
"uuid": ac.instanceUUID.String(),
|
||||
}).Set(0)
|
||||
logger.WithError(err).Warning("ws read error")
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
}
|
||||
@ -126,9 +160,6 @@ func (ac *APIController) startWSHandler() {
|
||||
func (ac *APIController) startWSHealth() {
|
||||
ticker := time.NewTicker(time.Second * 10)
|
||||
for ; true; <-ticker.C {
|
||||
if !ac.wsConn.IsConnected() {
|
||||
continue
|
||||
}
|
||||
aliveMsg := websocketMessage{
|
||||
Instruction: WebsocketInstructionHello,
|
||||
Args: map[string]interface{}{
|
||||
@ -137,10 +168,16 @@ func (ac *APIController) startWSHealth() {
|
||||
"uuid": ac.instanceUUID.String(),
|
||||
},
|
||||
}
|
||||
if ac.wsConn == nil {
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
}
|
||||
err := ac.wsConn.WriteJSON(aliveMsg)
|
||||
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
|
||||
if err != nil {
|
||||
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
} else {
|
||||
|
@ -73,7 +73,7 @@ func NewFlowExecutor(ctx context.Context, flowSlug string, refConfig *api.Config
|
||||
config.Scheme = refConfig.Scheme
|
||||
config.HTTPClient = &http.Client{
|
||||
Jar: jar,
|
||||
Transport: ak.NewUserAgentTransport(constants.OutpostUserAgent(), ak.NewTracingTransport(ctx, ak.GetTLSTransport())),
|
||||
Transport: ak.NewUserAgentTransport(constants.OutpostUserAgent(), ak.NewTracingTransport(rsp.Context(), ak.GetTLSTransport())),
|
||||
}
|
||||
token := strings.Split(refConfig.DefaultHeader["Authorization"], " ")[1]
|
||||
config.AddDefaultHeader(HeaderAuthentikOutpostToken, token)
|
||||
|
@ -1,5 +1,11 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
OCTop = "top"
|
||||
OCDomain = "domain"
|
||||
OCNSContainer = "nsContainer"
|
||||
)
|
||||
|
||||
const (
|
||||
OCGroup = "group"
|
||||
OCGroupOfUniqueNames = "groupOfUniqueNames"
|
||||
@ -19,3 +25,42 @@ const (
|
||||
OUGroups = "groups"
|
||||
OUVirtualGroups = "virtual-groups"
|
||||
)
|
||||
|
||||
func GetDomainOCs() map[string]bool {
|
||||
return map[string]bool{
|
||||
OCTop: true,
|
||||
OCDomain: true,
|
||||
}
|
||||
}
|
||||
|
||||
func GetContainerOCs() map[string]bool {
|
||||
return map[string]bool{
|
||||
OCTop: true,
|
||||
OCNSContainer: true,
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserOCs() map[string]bool {
|
||||
return map[string]bool{
|
||||
OCUser: true,
|
||||
OCOrgPerson: true,
|
||||
OCInetOrgPerson: true,
|
||||
OCAKUser: true,
|
||||
}
|
||||
}
|
||||
|
||||
func GetGroupOCs() map[string]bool {
|
||||
return map[string]bool{
|
||||
OCGroup: true,
|
||||
OCGroupOfUniqueNames: true,
|
||||
OCAKGroup: true,
|
||||
}
|
||||
}
|
||||
|
||||
func GetVirtualGroupOCs() map[string]bool {
|
||||
return map[string]bool{
|
||||
OCGroup: true,
|
||||
OCGroupOfUniqueNames: true,
|
||||
OCAKVirtualGroup: true,
|
||||
}
|
||||
}
|
||||
|
@ -2,14 +2,20 @@ package ldap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/nmcclain/ldap"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api"
|
||||
"goauthentik.io/internal/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/bind"
|
||||
ldapConstants "goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
"goauthentik.io/internal/outpost/ldap/utils"
|
||||
)
|
||||
|
||||
type ProviderInstance struct {
|
||||
@ -50,6 +56,10 @@ func (pi *ProviderInstance) GetBaseGroupDN() string {
|
||||
return pi.GroupDN
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetBaseVirtualGroupDN() string {
|
||||
return pi.VirtualGroupDN
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetBaseUserDN() string {
|
||||
return pi.UserDN
|
||||
}
|
||||
@@ -82,3 +92,77 @@ func (pi *ProviderInstance) GetFlowSlug() string {
 func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {
     return pi.searchAllowedGroups
 }
+
+func (pi *ProviderInstance) GetBaseEntry() *ldap.Entry {
+    return &ldap.Entry{
+        DN: pi.GetBaseDN(),
+        Attributes: []*ldap.EntryAttribute{
+            {
+                Name:   "distinguishedName",
+                Values: []string{pi.GetBaseDN()},
+            },
+            {
+                Name:   "objectClass",
+                Values: []string{ldapConstants.OCTop, ldapConstants.OCDomain},
+            },
+            {
+                Name:   "supportedLDAPVersion",
+                Values: []string{"3"},
+            },
+            {
+                Name: "namingContexts",
+                Values: []string{
+                    pi.GetBaseDN(),
+                    pi.GetBaseUserDN(),
+                    pi.GetBaseGroupDN(),
+                    pi.GetBaseVirtualGroupDN(),
+                },
+            },
+            {
+                Name:   "vendorName",
+                Values: []string{"goauthentik.io"},
+            },
+            {
+                Name:   "vendorVersion",
+                Values: []string{fmt.Sprintf("authentik LDAP Outpost Version %s (build %s)", constants.VERSION, constants.BUILD())},
+            },
+        },
+    }
+}
+
+func (pi *ProviderInstance) GetNeededObjects(scope int, baseDN string, filterOC string) (bool, bool) {
+    needUsers := false
+    needGroups := false
+
+    // We only want to load users/groups if we're actually going to be asked
+    // for at least one user or group based on the search's base DN and scope.
+    //
+    // If our requested base DN doesn't match any of the container DNs, then
+    // we're probably loading a user or group. If it does, then make sure our
+    // scope will eventually take us to users or groups.
+    if (baseDN == pi.BaseDN || strings.HasSuffix(baseDN, pi.UserDN)) && utils.IncludeObjectClass(filterOC, ldapConstants.GetUserOCs()) {
+        if baseDN != pi.UserDN && baseDN != pi.BaseDN ||
+            baseDN == pi.BaseDN && scope > 1 ||
+            baseDN == pi.UserDN && scope > 0 {
+            needUsers = true
+        }
+    }
+
+    if (baseDN == pi.BaseDN || strings.HasSuffix(baseDN, pi.GroupDN)) && utils.IncludeObjectClass(filterOC, ldapConstants.GetGroupOCs()) {
+        if baseDN != pi.GroupDN && baseDN != pi.BaseDN ||
+            baseDN == pi.BaseDN && scope > 1 ||
+            baseDN == pi.GroupDN && scope > 0 {
+            needGroups = true
+        }
+    }
+
+    if (baseDN == pi.BaseDN || strings.HasSuffix(baseDN, pi.VirtualGroupDN)) && utils.IncludeObjectClass(filterOC, ldapConstants.GetVirtualGroupOCs()) {
+        if baseDN != pi.VirtualGroupDN && baseDN != pi.BaseDN ||
+            baseDN == pi.BaseDN && scope > 1 ||
+            baseDN == pi.VirtualGroupDN && scope > 0 {
+            needUsers = true
+        }
+    }
+
+    return needUsers, needGroups
+}
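GetNeededObjects decides, before any API call is made, whether a search can possibly return users or groups at all. The sketch below reproduces only the shape of the user branch with hypothetical DNs; scope values follow RFC 4511 (0 = base object, 1 = single level, 2 = whole subtree), matching the nmcclain/ldap constants:

package main

import (
    "fmt"
    "strings"
)

// needUsers mirrors the user branch above: a subtree search from the
// root must descend two levels to reach entries under ou=users, while a
// search rooted at ou=users itself only needs one.
func needUsers(baseDN, userDN, rootDN string, scope int) bool {
    if baseDN != rootDN && !strings.HasSuffix(baseDN, userDN) {
        return false // this search can never reach the user container
    }
    if baseDN != userDN && baseDN != rootDN {
        return true // base DN already points below ou=users, e.g. one user
    }
    if baseDN == rootDN && scope > 1 {
        return true // subtree search from the root reaches users
    }
    if baseDN == userDN && scope > 0 {
        return true // one-level or deeper search rooted at ou=users
    }
    return false
}

func main() {
    root := "dc=ldap,dc=goauthentik,dc=io" // hypothetical base DN
    users := "ou=users," + root
    fmt.Println(needUsers(root, users, root, 1))                // false: only containers
    fmt.Println(needUsers(root, users, root, 2))                // true
    fmt.Println(needUsers("cn=akadmin,"+users, users, root, 0)) // true: base DN is a user
}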
@@ -52,7 +52,7 @@ func (ls *LDAPServer) StartLDAPServer() error {

     ln, err := net.Listen("tcp", listen)
     if err != nil {
-        ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
+        ls.log.WithField("listen", listen).WithError(err).Fatalf("listen failed")
     }
     proxyListener := &proxyproto.Listener{Listener: ln}
     defer proxyListener.Close()
@@ -37,7 +37,7 @@ func (ls *LDAPServer) StartLDAPTLSServer() error {

     ln, err := net.Listen("tcp", listen)
     if err != nil {
-        ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
+        ls.log.WithField("listen", listen).WithError(err).Fatalf("listen failed")
     }

     proxyListener := &proxyproto.Listener{Listener: ln}
@@ -3,6 +3,7 @@ package ldap

 import (
     "errors"
     "net"
+    "strings"

     "github.com/getsentry/sentry-go"
     goldap "github.com/go-ldap/ldap/v3"
@@ -41,13 +42,13 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n
     if searchReq.BaseDN == "" {
         return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultSuccess}, nil
     }
-    bd, err := goldap.ParseDN(searchReq.BaseDN)
+    bd, err := goldap.ParseDN(strings.ToLower(searchReq.BaseDN))
     if err != nil {
         req.Log().WithError(err).Info("failed to parse basedn")
         return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("invalid DN")
     }
     for _, provider := range ls.providers {
-        providerBase, _ := goldap.ParseDN(provider.BaseDN)
+        providerBase, _ := goldap.ParseDN(strings.ToLower(provider.BaseDN))
         if providerBase.AncestorOf(bd) || providerBase.Equal(bd) {
             return provider.searcher.Search(req)
         }
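The fix lowercases both DNs before comparison because RFC 4514 DNs match case-insensitively, while go-ldap's ParseDN preserves the input casing. A standalone illustration using the same library calls as the hunk above:

package main

import (
    "fmt"
    "strings"

    goldap "github.com/go-ldap/ldap/v3"
)

func main() {
    // Without ToLower, a client sending "DC=..." would never match a
    // provider configured with "dc=..." — exactly the routing bug fixed
    // above. DNs here are hypothetical.
    base, _ := goldap.ParseDN(strings.ToLower("DC=ldap,DC=goauthentik,DC=io"))
    req, _ := goldap.ParseDN(strings.ToLower("CN=akadmin,OU=users,DC=ldap,DC=goauthentik,DC=io"))
    fmt.Println(base.AncestorOf(req)) // true: the search belongs to this provider
    fmt.Println(base.Equal(req))      // false
}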
@@ -4,16 +4,15 @@ import (
     "errors"
     "fmt"
     "strings"
-    "sync"

     log "github.com/sirupsen/logrus"
+    "golang.org/x/sync/errgroup"

     "github.com/getsentry/sentry-go"
     "github.com/nmcclain/ldap"
     "github.com/prometheus/client_golang/prometheus"
     "goauthentik.io/api"
     "goauthentik.io/internal/outpost/ldap/constants"
     "goauthentik.io/internal/outpost/ldap/flags"
     "goauthentik.io/internal/outpost/ldap/group"
     "goauthentik.io/internal/outpost/ldap/metrics"
     "goauthentik.io/internal/outpost/ldap/search"
@@ -35,26 +34,11 @@ func NewDirectSearcher(si server.LDAPServerInstance) *DirectSearcher {
     return ds
 }

-func (ds *DirectSearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
-    if f.UserInfo == nil {
-        u, _, err := ds.si.GetAPIClient().CoreApi.CoreUsersRetrieve(req.Context(), f.UserPk).Execute()
-        if err != nil {
-            req.Log().WithError(err).Warning("Failed to get user info")
-            return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
-        }
-        f.UserInfo = &u
-    }
-    entries := make([]*ldap.Entry, 1)
-    entries[0] = ds.si.UserEntry(*f.UserInfo)
-    return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
-}
-
 func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
     accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
-    baseDN := strings.ToLower("," + ds.si.GetBaseDN())
+    baseDN := strings.ToLower(ds.si.GetBaseDN())

-    entries := []*ldap.Entry{}
-    filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
+    filterOC, err := ldap.GetFilterObjectClass(req.Filter)
     if err != nil {
         metrics.RequestsRejected.With(prometheus.Labels{
             "outpost_name": ds.si.GetOutpostName(),
@@ -75,7 +59,7 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
         }).Inc()
         return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
     }
-    if !strings.HasSuffix(req.BindDN, baseDN) {
+    if !strings.HasSuffix(req.BindDN, ","+baseDN) {
         metrics.RequestsRejected.With(prometheus.Labels{
             "outpost_name": ds.si.GetOutpostName(),
             "type":         "search",
@@ -98,15 +82,6 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
         }).Inc()
         return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
     }

-    if req.Scope == ldap.ScopeBaseObject {
-        req.Log().Debug("base scope, showing domain info")
-        return ds.SearchBase(req, flags.CanSearch)
-    }
-    if !flags.CanSearch {
-        req.Log().Debug("User can't search, showing info about user")
-        return ds.SearchMe(req, flags)
-    }
     accsp.Finish()

     parsedFilter, err := ldap.CompileFilter(req.Filter)
@@ -121,99 +96,176 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
         return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
     }

+    entries := make([]*ldap.Entry, 0)
+
     // Create a custom client to set additional headers
     c := api.NewAPIClient(ds.si.GetAPIClient().GetConfig())
     c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)

-    switch filterEntity {
-    default:
-        metrics.RequestsRejected.With(prometheus.Labels{
-            "outpost_name": ds.si.GetOutpostName(),
-            "type":         "search",
-            "reason":       "unhandled_filter_type",
-            "dn":           req.BindDN,
-            "client":       req.RemoteAddr(),
-        }).Inc()
-        return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
-    case constants.OCGroupOfUniqueNames:
-        fallthrough
-    case constants.OCAKGroup:
-        fallthrough
-    case constants.OCAKVirtualGroup:
-        fallthrough
-    case constants.OCGroup:
-        wg := sync.WaitGroup{}
-        wg.Add(2)
+    scope := req.SearchRequest.Scope
+    needUsers, needGroups := ds.si.GetNeededObjects(scope, req.BaseDN, filterOC)

-        gEntries := make([]*ldap.Entry, 0)
-        uEntries := make([]*ldap.Entry, 0)
+    if scope >= 0 && req.BaseDN == baseDN {
+        if utils.IncludeObjectClass(filterOC, constants.GetDomainOCs()) {
+            entries = append(entries, ds.si.GetBaseEntry())
+        }

-        go func() {
-            defer wg.Done()
+        scope -= 1 // Bring it from WholeSubtree to SingleLevel and so on
+    }

+    var users *[]api.User
+    var groups *[]api.Group
+
+    errs, _ := errgroup.WithContext(req.Context())
+
+    if needUsers {
+        errs.Go(func() error {
+            if flags.CanSearch {
+                uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
+                searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
+
+                if skip {
+                    req.Log().Trace("Skip backend request")
+                    return nil
+                }
+
+                u, _, e := searchReq.Execute()
+                uapisp.Finish()
+
+                if err != nil {
+                    req.Log().WithError(err).Warning("failed to get users")
+                    return e
+                }
+
+                users = &u.Results
+            } else {
+                if flags.UserInfo == nil {
+                    uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
+                    u, _, err := c.CoreApi.CoreUsersRetrieve(req.Context(), flags.UserPk).Execute()
+                    uapisp.Finish()
+
+                    if err != nil {
+                        req.Log().WithError(err).Warning("Failed to get user info")
+                        return fmt.Errorf("failed to get userinfo")
+                    }
+
+                    flags.UserInfo = &u
+                }
+
+                u := make([]api.User, 1)
+                u[0] = *flags.UserInfo
+
+                users = &u
+            }
+
+            return nil
+        })
+    }
+
+    if needGroups {
+        errs.Go(func() error {
             gapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_group")
             searchReq, skip := utils.ParseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false)
             if skip {
                 req.Log().Trace("Skip backend request")
-                return
+                return nil
             }
-            groups, _, err := searchReq.Execute()
+
+            if !flags.CanSearch {
+                // If they can't search, filter all groups by those they're a member of
+                searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
+            }
+
+            g, _, err := searchReq.Execute()
             gapisp.Finish()
             if err != nil {
                 req.Log().WithError(err).Warning("failed to get groups")
-                return
+                return err
             }
-            req.Log().WithField("count", len(groups.Results)).Trace("Got results from API")
+            req.Log().WithField("count", len(g.Results)).Trace("Got results from API")

-            for _, g := range groups.Results {
-                gEntries = append(gEntries, group.FromAPIGroup(g, ds.si).Entry())
-            }
-        }()
-
-        go func() {
-            defer wg.Done()
-            uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
-            searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
-            if skip {
-                req.Log().Trace("Skip backend request")
-                return
-            }
-            users, _, err := searchReq.Execute()
-            uapisp.Finish()
-            if err != nil {
-                req.Log().WithError(err).Warning("failed to get users")
-                return
+            if !flags.CanSearch {
+                for i, results := range g.Results {
+                    // If they can't search, remove any users from the group results except the one we're looking for.
+                    g.Results[i].Users = []int32{flags.UserPk}
+                    for _, u := range results.UsersObj {
+                        if u.Pk == flags.UserPk {
+                            g.Results[i].UsersObj = []api.GroupMember{u}
+                            break
+                        }
+                    }
+                }
             }

-            for _, u := range users.Results {
-                uEntries = append(uEntries, group.FromAPIUser(u, ds.si).Entry())
-            }
-        }()
-        wg.Wait()
-        entries = append(gEntries, uEntries...)
-    case "":
-        fallthrough
-    case constants.OCOrgPerson:
-        fallthrough
-    case constants.OCInetOrgPerson:
-        fallthrough
-    case constants.OCAKUser:
-        fallthrough
-    case constants.OCUser:
-        uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
-        searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
-        if skip {
-            req.Log().Trace("Skip backend request")
-            return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
+            groups = &g.Results
+
+            return nil
+        })
+    }
+
+    err = errs.Wait()
+
+    if err != nil {
+        return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, err
+    }
+
+    if scope >= 0 && (req.BaseDN == ds.si.GetBaseDN() || strings.HasSuffix(req.BaseDN, ds.si.GetBaseUserDN())) {
+        singleu := strings.HasSuffix(req.BaseDN, ","+ds.si.GetBaseUserDN())
+
+        if !singleu && utils.IncludeObjectClass(filterOC, constants.GetContainerOCs()) {
+            entries = append(entries, utils.GetContainerEntry(filterOC, ds.si.GetBaseUserDN(), constants.OUUsers))
+            scope -= 1
         }
-        users, _, err := searchReq.Execute()
-        uapisp.Finish()

-        if err != nil {
-            return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err)
+        if scope >= 0 && users != nil && utils.IncludeObjectClass(filterOC, constants.GetUserOCs()) {
+            for _, u := range *users {
+                entry := ds.si.UserEntry(u)
+                if req.BaseDN == entry.DN || !singleu {
+                    entries = append(entries, entry)
+                }
+            }
+        }
-        for _, u := range users.Results {
-            entries = append(entries, ds.si.UserEntry(u))

+        scope += 1 // Return the scope to what it was before we descended
     }

+    if scope >= 0 && (req.BaseDN == ds.si.GetBaseDN() || strings.HasSuffix(req.BaseDN, ds.si.GetBaseGroupDN())) {
+        singleg := strings.HasSuffix(req.BaseDN, ","+ds.si.GetBaseGroupDN())
+
+        if !singleg && utils.IncludeObjectClass(filterOC, constants.GetContainerOCs()) {
+            entries = append(entries, utils.GetContainerEntry(filterOC, ds.si.GetBaseGroupDN(), constants.OUGroups))
+            scope -= 1
+        }
+
+        if scope >= 0 && groups != nil && utils.IncludeObjectClass(filterOC, constants.GetGroupOCs()) {
+            for _, g := range *groups {
+                entry := group.FromAPIGroup(g, ds.si).Entry()
+                if req.BaseDN == entry.DN || !singleg {
+                    entries = append(entries, entry)
+                }
+            }
+        }
+
+        scope += 1 // Return the scope to what it was before we descended
+    }
+
+    if scope >= 0 && (req.BaseDN == ds.si.GetBaseDN() || strings.HasSuffix(req.BaseDN, ds.si.GetBaseVirtualGroupDN())) {
+        singlevg := strings.HasSuffix(req.BaseDN, ","+ds.si.GetBaseVirtualGroupDN())
+
+        if !singlevg || utils.IncludeObjectClass(filterOC, constants.GetContainerOCs()) {
+            entries = append(entries, utils.GetContainerEntry(filterOC, ds.si.GetBaseVirtualGroupDN(), constants.OUVirtualGroups))
+            scope -= 1
+        }
+
+        if scope >= 0 && users != nil && utils.IncludeObjectClass(filterOC, constants.GetVirtualGroupOCs()) {
+            for _, u := range *users {
+                entry := group.FromAPIUser(u, ds.si).Entry()
+                if req.BaseDN == entry.DN || !singlevg {
+                    entries = append(entries, entry)
+                }
+            }
+        }
+    }

     return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
 }
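The rework above replaces the sync.WaitGroup goroutines, which logged failures and silently returned nothing, with errgroup: the first error from either API fetch aborts the search and is mapped to LDAPResultOperationsError. A minimal sketch of that pattern with stand-in fetches instead of the authentik API calls:

package main

import (
    "context"
    "errors"
    "fmt"

    "golang.org/x/sync/errgroup"
)

func fetchAll(ctx context.Context, failGroups bool) ([]string, []string, error) {
    var users, groups []string
    errs, _ := errgroup.WithContext(ctx)
    errs.Go(func() error {
        users = []string{"akadmin"} // stand-in for the users API call
        return nil
    })
    errs.Go(func() error {
        if failGroups {
            return errors.New("api: groups request failed")
        }
        groups = []string{"authentik Admins"} // stand-in for the groups API call
        return nil
    })
    // Wait returns the first non-nil error; the caller can now surface
    // it instead of answering with an empty result set.
    if err := errs.Wait(); err != nil {
        return nil, nil, err
    }
    return users, groups, nil
}

func main() {
    u, g, err := fetchAll(context.Background(), false)
    fmt.Println(u, g, err)
    _, _, err = fetchAll(context.Background(), true)
    fmt.Println(err)
}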
@@ -1,54 +0,0 @@
-package memory
-
-import (
-    "fmt"
-
-    "github.com/nmcclain/ldap"
-    "goauthentik.io/internal/constants"
-    "goauthentik.io/internal/outpost/ldap/search"
-)
-
-func (ms *MemorySearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) {
-    dn := ""
-    if authz {
-        dn = req.SearchRequest.BaseDN
-    }
-    return ldap.ServerSearchResult{
-        Entries: []*ldap.Entry{
-            {
-                DN: dn,
-                Attributes: []*ldap.EntryAttribute{
-                    {
-                        Name:   "distinguishedName",
-                        Values: []string{ms.si.GetBaseDN()},
-                    },
-                    {
-                        Name:   "objectClass",
-                        Values: []string{"top", "domain"},
-                    },
-                    {
-                        Name:   "supportedLDAPVersion",
-                        Values: []string{"3"},
-                    },
-                    {
-                        Name: "namingContexts",
-                        Values: []string{
-                            ms.si.GetBaseDN(),
-                            ms.si.GetBaseUserDN(),
-                            ms.si.GetBaseGroupDN(),
-                        },
-                    },
-                    {
-                        Name:   "vendorName",
-                        Values: []string{"goauthentik.io"},
-                    },
-                    {
-                        Name:   "vendorVersion",
-                        Values: []string{fmt.Sprintf("authentik LDAP Outpost Version %s (build %s)", constants.VERSION, constants.BUILD())},
-                    },
-                },
-            },
-        },
-        Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess,
-    }, nil
-}
@@ -11,11 +11,11 @@ import (
     log "github.com/sirupsen/logrus"
     "goauthentik.io/api"
     "goauthentik.io/internal/outpost/ldap/constants"
    "goauthentik.io/internal/outpost/ldap/flags"
     "goauthentik.io/internal/outpost/ldap/group"
     "goauthentik.io/internal/outpost/ldap/metrics"
     "goauthentik.io/internal/outpost/ldap/search"
     "goauthentik.io/internal/outpost/ldap/server"
     "goauthentik.io/internal/outpost/ldap/utils"
 )

 type MemorySearcher struct {
@@ -37,29 +37,11 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
     return ms
 }

-func (ms *MemorySearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
-    if f.UserInfo == nil {
-        for _, u := range ms.users {
-            if u.Pk == f.UserPk {
-                f.UserInfo = &u
-            }
-        }
-        if f.UserInfo == nil {
-            req.Log().WithField("pk", f.UserPk).Warning("User with pk is not in local cache")
-            return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
-        }
-    }
-    entries := make([]*ldap.Entry, 1)
-    entries[0] = ms.si.UserEntry(*f.UserInfo)
-    return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
-}
-
 func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
     accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
-    baseDN := strings.ToLower("," + ms.si.GetBaseDN())
+    baseDN := strings.ToLower(ms.si.GetBaseDN())

-    entries := []*ldap.Entry{}
-    filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
+    filterOC, err := ldap.GetFilterObjectClass(req.Filter)
     if err != nil {
         metrics.RequestsRejected.With(prometheus.Labels{
             "outpost_name": ms.si.GetOutpostName(),
@@ -80,7 +62,7 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult,
         }).Inc()
         return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
     }
-    if !strings.HasSuffix(req.BindDN, baseDN) {
+    if !strings.HasSuffix(req.BindDN, ","+baseDN) {
         metrics.RequestsRejected.With(prometheus.Labels{
             "outpost_name": ms.si.GetOutpostName(),
             "type":         "search",
@@ -103,52 +85,132 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult,
         }).Inc()
         return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
     }

-    if req.Scope == ldap.ScopeBaseObject {
-        req.Log().Debug("base scope, showing domain info")
-        return ms.SearchBase(req, flags.CanSearch)
-    }
-    if !flags.CanSearch {
-        req.Log().Debug("User can't search, showing info about user")
-        return ms.SearchMe(req, flags)
-    }
     accsp.Finish()

-    switch filterEntity {
-    default:
-        metrics.RequestsRejected.With(prometheus.Labels{
-            "outpost_name": ms.si.GetOutpostName(),
-            "type":         "search",
-            "reason":       "unhandled_filter_type",
-            "dn":           req.BindDN,
-            "client":       req.RemoteAddr(),
-        }).Inc()
-        return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
-    case constants.OCGroupOfUniqueNames:
-        fallthrough
-    case constants.OCAKGroup:
-        fallthrough
-    case constants.OCAKVirtualGroup:
-        fallthrough
-    case constants.OCGroup:
-        for _, g := range ms.groups {
-            entries = append(entries, group.FromAPIGroup(g, ms.si).Entry())
+    entries := make([]*ldap.Entry, 0)
+
+    scope := req.SearchRequest.Scope
+    needUsers, needGroups := ms.si.GetNeededObjects(scope, req.BaseDN, filterOC)
+
+    if scope >= 0 && req.BaseDN == baseDN {
+        if utils.IncludeObjectClass(filterOC, constants.GetDomainOCs()) {
+            entries = append(entries, ms.si.GetBaseEntry())
         }
-        for _, u := range ms.users {
-            entries = append(entries, group.FromAPIUser(u, ms.si).Entry())
-        }
-    case "":
-        fallthrough
-    case constants.OCOrgPerson:
-        fallthrough
-    case constants.OCInetOrgPerson:
-        fallthrough
-    case constants.OCAKUser:
-        fallthrough
-    case constants.OCUser:
-        for _, u := range ms.users {
-            entries = append(entries, ms.si.UserEntry(u))

+        scope -= 1 // Bring it from WholeSubtree to SingleLevel and so on
     }

+    var users *[]api.User
+    var groups []*group.LDAPGroup
+
+    if needUsers {
+        if flags.CanSearch {
+            users = &ms.users
+        } else {
+            if flags.UserInfo == nil {
+                for i, u := range ms.users {
+                    if u.Pk == flags.UserPk {
+                        flags.UserInfo = &ms.users[i]
+                    }
+                }
+
+                if flags.UserInfo == nil {
+                    req.Log().WithField("pk", flags.UserPk).Warning("User with pk is not in local cache")
+                    err = fmt.Errorf("failed to get userinfo")
+                }
+            }
+
+            u := make([]api.User, 1)
+            u[0] = *flags.UserInfo
+
+            users = &u
+        }
+    }
+
+    if needGroups {
+        groups = make([]*group.LDAPGroup, 0)
+
+        for _, g := range ms.groups {
+            if flags.CanSearch {
+                groups = append(groups, group.FromAPIGroup(g, ms.si))
+            } else {
+                // If the user cannot search, we're going to only return
+                // the groups they're in _and_ only return themselves
+                // as a member.
+                for _, u := range g.UsersObj {
+                    if flags.UserPk == u.Pk {
+                        // TODO: Is there a better way to clone this object?
+                        fg := api.NewGroup(g.Pk, g.Name, g.Parent, g.ParentName, []int32{flags.UserPk}, []api.GroupMember{u})
+                        fg.SetAttributes(*g.Attributes)
+                        fg.SetIsSuperuser(*g.IsSuperuser)
+                        groups = append(groups, group.FromAPIGroup(*fg, ms.si))
+                        break
+                    }
+                }
+            }
+        }
+    }
+
+    if err != nil {
+        return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, err
+    }
+
+    if scope >= 0 && (req.BaseDN == ms.si.GetBaseDN() || strings.HasSuffix(req.BaseDN, ms.si.GetBaseUserDN())) {
+        singleu := strings.HasSuffix(req.BaseDN, ","+ms.si.GetBaseUserDN())
+
+        if !singleu && utils.IncludeObjectClass(filterOC, constants.GetContainerOCs()) {
+            entries = append(entries, utils.GetContainerEntry(filterOC, ms.si.GetBaseUserDN(), constants.OUUsers))
+            scope -= 1
+        }
+
+        if scope >= 0 && users != nil && utils.IncludeObjectClass(filterOC, constants.GetUserOCs()) {
+            for _, u := range *users {
+                entry := ms.si.UserEntry(u)
+                if req.BaseDN == entry.DN || !singleu {
+                    entries = append(entries, entry)
+                }
+            }
+        }
+
+        scope += 1 // Return the scope to what it was before we descended
+    }
+
+    if scope >= 0 && (req.BaseDN == ms.si.GetBaseDN() || strings.HasSuffix(req.BaseDN, ms.si.GetBaseGroupDN())) {
+        singleg := strings.HasSuffix(req.BaseDN, ","+ms.si.GetBaseGroupDN())
+
+        if !singleg && utils.IncludeObjectClass(filterOC, constants.GetContainerOCs()) {
+            entries = append(entries, utils.GetContainerEntry(filterOC, ms.si.GetBaseGroupDN(), constants.OUGroups))
+            scope -= 1
+        }
+
+        if scope >= 0 && groups != nil && utils.IncludeObjectClass(filterOC, constants.GetGroupOCs()) {
+            for _, g := range groups {
+                if req.BaseDN == g.DN || !singleg {
+                    entries = append(entries, g.Entry())
+                }
+            }
+        }
+
+        scope += 1 // Return the scope to what it was before we descended
+    }
+
+    if scope >= 0 && (req.BaseDN == ms.si.GetBaseDN() || strings.HasSuffix(req.BaseDN, ms.si.GetBaseVirtualGroupDN())) {
+        singlevg := strings.HasSuffix(req.BaseDN, ","+ms.si.GetBaseVirtualGroupDN())
+
+        if !singlevg && utils.IncludeObjectClass(filterOC, constants.GetContainerOCs()) {
+            entries = append(entries, utils.GetContainerEntry(filterOC, ms.si.GetBaseVirtualGroupDN(), constants.OUVirtualGroups))
+            scope -= 1
+        }
+
+        if scope >= 0 && users != nil && utils.IncludeObjectClass(filterOC, constants.GetVirtualGroupOCs()) {
+            for _, u := range *users {
+                entry := group.FromAPIUser(u, ms.si).Entry()
+                if req.BaseDN == entry.DN || !singlevg {
+                    entries = append(entries, entry)
+                }
+            }
+        }
+    }

     return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
 }
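For binders without the can_search flag, the memory searcher now clones each matching group so that only the requesting user remains visible as a member. The sketch below mirrors that behaviour on simplified stand-in types (the real code clones via api.NewGroup):

package main

import "fmt"

type member struct {
    Pk   int32
    Name string
}

type apiGroup struct {
    Name     string
    Users    []int32
    UsersObj []member
}

// restrictToSelf keeps only groups the user belongs to, and rewrites
// each kept group so the user is its sole visible member — the same
// shape as the loop above.
func restrictToSelf(groups []apiGroup, userPk int32) []apiGroup {
    out := make([]apiGroup, 0, len(groups))
    for _, g := range groups {
        for _, u := range g.UsersObj {
            if u.Pk == userPk {
                out = append(out, apiGroup{
                    Name:     g.Name,
                    Users:    []int32{userPk},
                    UsersObj: []member{u},
                })
                break
            }
        }
    }
    return out
}

func main() {
    gs := []apiGroup{
        {Name: "admins", Users: []int32{1, 2}, UsersObj: []member{{1, "akadmin"}, {2, "bob"}}},
        {Name: "other", Users: []int32{1}, UsersObj: []member{{1, "akadmin"}}},
    }
    fmt.Println(restrictToSelf(gs, 2)) // only "admins", with bob as sole member
}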
@@ -25,6 +25,7 @@ type Request struct {
 func NewRequest(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (*Request, *sentry.Span) {
     rid := uuid.New().String()
     bindDN = strings.ToLower(bindDN)
+    searchReq.BaseDN = strings.ToLower(searchReq.BaseDN)
     span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search"))
     span.SetTag("request_uid", rid)
     span.SetTag("user.username", bindDN)
@@ -1,6 +1,8 @@
 package search

-import "github.com/nmcclain/ldap"
+import (
+    "github.com/nmcclain/ldap"
+)

 type Searcher interface {
     Search(req *Request) (ldap.ServerSearchResult, error)
@@ -19,6 +19,7 @@ type LDAPServerInstance interface {

     GetBaseDN() string
     GetBaseGroupDN() string
+    GetBaseVirtualGroupDN() string
     GetBaseUserDN() string

     GetUserDN(string) string
@@ -32,4 +33,7 @@ type LDAPServerInstance interface {

     GetFlags(string) (flags.UserFlags, bool)
     SetFlags(string, flags.UserFlags)
+
+    GetBaseEntry() *ldap.Entry
+    GetNeededObjects(int, string, string) (bool, bool)
 }
@@ -5,6 +5,7 @@ import (

     "github.com/nmcclain/ldap"
     log "github.com/sirupsen/logrus"
+    ldapConstants "goauthentik.io/internal/outpost/ldap/constants"
 )

 func BoolToString(in bool) string {
@@ -84,3 +85,35 @@ func MustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string
     }
     return attrs
 }
+
+func IncludeObjectClass(searchOC string, ocs map[string]bool) bool {
+    if searchOC == "" {
+        return true
+    }
+
+    return ocs[searchOC]
+}
+
+func GetContainerEntry(filterOC string, dn string, ou string) *ldap.Entry {
+    if IncludeObjectClass(filterOC, ldapConstants.GetContainerOCs()) {
+        return &ldap.Entry{
+            DN: dn,
+            Attributes: []*ldap.EntryAttribute{
+                {
+                    Name:   "distinguishedName",
+                    Values: []string{dn},
+                },
+                {
+                    Name:   "objectClass",
+                    Values: []string{"top", "nsContainer"},
+                },
+                {
+                    Name:   "commonName",
+                    Values: []string{ou},
+                },
+            },
+        }
+    }
+
+    return nil
+}
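GetContainerEntry synthesizes the ou=users/ou=groups container objects on the fly rather than storing them anywhere. A trimmed, runnable copy of the entry construction, without the objectClass filter guard, using the same nmcclain/ldap types; the DN is hypothetical:

package main

import (
    "fmt"

    "github.com/nmcclain/ldap"
)

// buildContainer shows the synthetic nsContainer entry a client sees
// when it searches one level below the base DN.
func buildContainer(dn, ou string) *ldap.Entry {
    return &ldap.Entry{
        DN: dn,
        Attributes: []*ldap.EntryAttribute{
            {Name: "distinguishedName", Values: []string{dn}},
            {Name: "objectClass", Values: []string{"top", "nsContainer"}},
            {Name: "commonName", Values: []string{ou}},
        },
    }
}

func main() {
    e := buildContainer("ou=users,dc=ldap,dc=goauthentik,dc=io", "users")
    fmt.Println(e.DN, len(e.Attributes))
}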
@@ -1,6 +1,7 @@
 package application

 import (
+    "context"
     "crypto/tls"
     "encoding/gob"
     "fmt"
@@ -52,11 +53,17 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, cs *ak.CryptoStore
         return nil, fmt.Errorf("failed to parse URL, skipping provider")
     }

-    ks := hs256.NewKeySet(*p.ClientSecret)
+    var ks oidc.KeySet
+    if contains(p.OidcConfiguration.IdTokenSigningAlgValuesSupported, "HS256") {
+        ks = hs256.NewKeySet(*p.ClientSecret)
+    } else {
+        ctx := context.WithValue(context.Background(), oauth2.HTTPClient, c)
+        ks = oidc.NewRemoteKeySet(ctx, p.OidcConfiguration.JwksUri)
+    }

     var verifier = oidc.NewVerifier(p.OidcConfiguration.Issuer, ks, &oidc.Config{
         ClientID: *p.ClientId,
-        SupportedSigningAlgs: []string{"HS256"},
+        SupportedSigningAlgs: []string{"RS256", "HS256"},
     })

     // Configure an OpenID Connect aware OAuth2 client.
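The verifier now falls back to the provider's remote JWKS whenever HS256 is not advertised, instead of assuming a shared-secret keyset. A sketch of that selection, assuming the go-oidc v3 API (the outpost may pin a different go-oidc version); hs256KeySet is a stand-in for the outpost's hs256.NewKeySet helper:

package main

import (
    "context"
    "fmt"
    "net/http"

    "github.com/coreos/go-oidc/v3/oidc"
    "golang.org/x/oauth2"
)

func contains(haystack []string, needle string) bool {
    for _, h := range haystack {
        if h == needle {
            return true
        }
    }
    return false
}

// newVerifier prefers a static HS256 keyset when the provider advertises
// HS256, otherwise it fetches the provider's JWKS over HTTP, routed
// through the outpost's own HTTP client via the oauth2.HTTPClient key.
func newVerifier(issuer, jwksURI, clientID string, algs []string, hs256KeySet oidc.KeySet, c *http.Client) *oidc.IDTokenVerifier {
    var ks oidc.KeySet
    if contains(algs, "HS256") {
        ks = hs256KeySet
    } else {
        ctx := context.WithValue(context.Background(), oauth2.HTTPClient, c)
        ks = oidc.NewRemoteKeySet(ctx, jwksURI)
    }
    return oidc.NewVerifier(issuer, ks, &oidc.Config{
        ClientID:             clientID,
        SupportedSigningAlgs: []string{"RS256", "HS256"},
    })
}

func main() {
    fmt.Println(contains([]string{"RS256"}, "HS256")) // false: remote JWKS path taken
}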
@@ -94,14 +101,14 @@ func NewApplication(p api.ProxyOutpostConfig, c *http.Client, cs *ak.CryptoStore
         if !ok {
             return l
         }
-        return l.WithField("request_username", c.Email)
+        return l.WithField("request_username", c.PreferredUsername)
     }))
     mux.Use(func(inner http.Handler) http.Handler {
         return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
             c, _ := a.getClaims(r)
             user := ""
             if c != nil {
-                user = c.Email
+                user = c.PreferredUsername
             }
             before := time.Now()
             inner.ServeHTTP(rw, r)
@@ -13,4 +13,6 @@ type Claims struct {
     Name              string   `json:"name"`
     PreferredUsername string   `json:"preferred_username"`
     Groups            []string `json:"groups"`
+
+    RawToken string
 }
@@ -5,24 +5,34 @@ import (
     "fmt"
     "net/http"
     "strings"

+    "goauthentik.io/internal/constants"
 )

-func (a *Application) addHeaders(r *http.Request, c *Claims) {
+func (a *Application) addHeaders(headers http.Header, c *Claims) {
     // https://goauthentik.io/docs/providers/proxy/proxy

     // Legacy headers, remove after 2022.1
-    r.Header.Set("X-Auth-Username", c.PreferredUsername)
-    r.Header.Set("X-Auth-Groups", strings.Join(c.Groups, "|"))
-    r.Header.Set("X-Forwarded-Email", c.Email)
-    r.Header.Set("X-Forwarded-Preferred-Username", c.PreferredUsername)
-    r.Header.Set("X-Forwarded-User", c.Sub)
+    headers.Set("X-Auth-Username", c.PreferredUsername)
+    headers.Set("X-Auth-Groups", strings.Join(c.Groups, "|"))
+    headers.Set("X-Forwarded-Email", c.Email)
+    headers.Set("X-Forwarded-Preferred-Username", c.PreferredUsername)
+    headers.Set("X-Forwarded-User", c.Sub)

     // New headers, unique prefix
-    r.Header.Set("X-authentik-username", c.PreferredUsername)
-    r.Header.Set("X-authentik-groups", strings.Join(c.Groups, "|"))
-    r.Header.Set("X-authentik-email", c.Email)
-    r.Header.Set("X-authentik-name", c.Name)
-    r.Header.Set("X-authentik-uid", c.Sub)
+    headers.Set("X-authentik-username", c.PreferredUsername)
+    headers.Set("X-authentik-groups", strings.Join(c.Groups, "|"))
+    headers.Set("X-authentik-email", c.Email)
+    headers.Set("X-authentik-name", c.Name)
+    headers.Set("X-authentik-uid", c.Sub)
+    headers.Set("X-authentik-jwt", c.RawToken)
+
+    // System headers
+    headers.Set("X-authentik-meta-jwks", a.proxyConfig.OidcConfiguration.JwksUri)
+    headers.Set("X-authentik-meta-outpost", a.outpostName)
+    headers.Set("X-authentik-meta-provider", a.proxyConfig.Name)
+    headers.Set("X-authentik-meta-app", a.proxyConfig.AssignedApplicationSlug)
+    headers.Set("X-authentik-meta-version", constants.OutpostUserAgent())

     userAttributes := c.Proxy.UserAttributes
     // Attempt to set basic auth based on user's attributes
@@ -39,7 +49,7 @@ func (a *Application) addHeaders(r *http.Request, c *Claims) {
         }
         authVal := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
         a.log.WithField("username", username).Trace("setting http basic auth")
-        r.Header["Authorization"] = []string{fmt.Sprintf("Basic %s", authVal)}
+        headers.Set("Authorization", fmt.Sprintf("Basic %s", authVal))
     }
     // Check if user has additional headers set that we should sent
     if additionalHeaders, ok := userAttributes["additionalHeaders"].(map[string]interface{}); ok {
@@ -48,15 +58,7 @@ func (a *Application) addHeaders(r *http.Request, c *Claims) {
             return
         }
         for key, value := range additionalHeaders {
-            r.Header.Set(key, toString(value))
-        }
-    }
-}
-
-func copyHeadersToResponse(rw http.ResponseWriter, r *http.Request) {
-    for headerKey, headers := range r.Header {
-        for _, value := range headers {
-            rw.Header().Set(headerKey, value)
+            headers.Set(key, toString(value))
         }
     }
 }
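addHeaders now takes an http.Header rather than an *http.Request, which is what lets forward_auth write the identity headers directly onto the response (making the removed copyHeadersToResponse helper unnecessary) while proxy mode keeps writing them onto the upstream request — both sides expose the same map type. A self-contained illustration; setIdentity plays the role of addHeaders:

package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"
)

// setIdentity only needs the header map, not the whole request or
// response object.
func setIdentity(h http.Header, username string) {
    h.Set("X-authentik-username", username)
}

func main() {
    r, _ := http.NewRequest("GET", "http://upstream.example", nil)
    rw := httptest.NewRecorder()

    setIdentity(r.Header, "akadmin")    // proxy mode: headers travel upstream
    setIdentity(rw.Header(), "akadmin") // forward_auth mode: headers go back to nginx/traefik

    fmt.Println(r.Header.Get("X-authentik-username"))
    fmt.Println(rw.Header().Get("X-authentik-username"))
}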
@@ -26,8 +26,9 @@ func (a *Application) configureForward() error {
 func (a *Application) forwardHandleTraefik(rw http.ResponseWriter, r *http.Request) {
     claims, err := a.getClaims(r)
     if claims != nil && err == nil {
-        a.addHeaders(r, claims)
-        copyHeadersToResponse(rw, r)
+        a.addHeaders(rw.Header(), claims)
+        rw.Header().Set("User-Agent", r.Header.Get("User-Agent"))
+        a.log.WithField("headers", rw.Header()).Trace("headers written to forward_auth")
         return
     } else if claims == nil && a.IsAllowlisted(r) {
         a.log.Trace("path can be accessed without authentication")
@@ -69,9 +70,10 @@ func (a *Application) forwardHandleTraefik(rw http.ResponseWriter, r *http.Reque
 func (a *Application) forwardHandleNginx(rw http.ResponseWriter, r *http.Request) {
     claims, err := a.getClaims(r)
     if claims != nil && err == nil {
-        a.addHeaders(r, claims)
-        copyHeadersToResponse(rw, r)
+        a.addHeaders(rw.Header(), claims)
+        rw.Header().Set("User-Agent", r.Header.Get("User-Agent"))
         rw.WriteHeader(200)
         a.log.WithField("headers", rw.Header()).Trace("headers written to forward_auth")
         return
     } else if claims == nil && a.IsAllowlisted(r) {
         a.log.Trace("path can be accessed without authentication")
@@ -8,6 +8,7 @@ import (
     "net/url"
     "time"

+    "github.com/getsentry/sentry-go"
     "github.com/prometheus/client_golang/prometheus"
     "goauthentik.io/internal/outpost/ak"
     "goauthentik.io/internal/outpost/proxyv2/metrics"
@@ -28,7 +29,8 @@ func (a *Application) configureProxy() error {
         return err
     }
     rp := &httputil.ReverseProxy{Director: a.proxyModifyRequest(u)}
-    rp.Transport = ak.NewTracingTransport(context.TODO(), a.getUpstreamTransport())
+    rsp := sentry.StartSpan(context.TODO(), "authentik.outposts.proxy.application_transport")
+    rp.Transport = ak.NewTracingTransport(rsp.Context(), a.getUpstreamTransport())
     rp.ErrorHandler = a.newProxyErrorHandler(templates.GetTemplates())
     rp.ModifyResponse = a.proxyModifyResponse
     a.mux.PathPrefix("/").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -39,7 +41,7 @@ func (a *Application) configureProxy() error {
         a.redirectToStart(rw, r)
         return
     } else {
-        a.addHeaders(r, claims)
+        a.addHeaders(r.Header, claims)
     }
     before := time.Now()
     rp.ServeHTTP(rw, r)
@@ -45,5 +45,6 @@ func (a *Application) redeemCallback(r *http.Request, shouldState string) (*Clai
     if err := idToken.Claims(&claims); err != nil {
         return nil, err
     }
+    claims.RawToken = rawIDToken
     return claims, nil
 }
@@ -2,6 +2,7 @@ package application

 import (
     "fmt"
+    "os"
     "strconv"

     "github.com/gorilla/sessions"
@@ -26,14 +27,14 @@ func (a *Application) getStore(p api.ProxyOutpostConfig) sessions.Store {
         a.log.Info("using redis session backend")
         store = rs
     } else {
-        cs := sessions.NewCookieStore([]byte(*p.CookieSecret))
+        cs := sessions.NewFilesystemStore(os.TempDir(), []byte(*p.CookieSecret))
         cs.Options.Domain = *p.CookieDomain
         if p.TokenValidity.IsSet() {
             t := p.TokenValidity.Get()
+            // Add one to the validity to ensure we don't have a session with indefinite length
             cs.Options.MaxAge = int(*t) + 1
         }
-        a.log.Info("using cookie session backend")
+        a.log.Info("using filesystem session backend")
         store = cs
     }
     return store
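Cookie sessions carry the whole session payload in the cookie itself and run into browser cookie-size limits once many claims are stored; the filesystem store keeps only a session ID in the cookie. A sketch of the new backend wiring with hypothetical parameters; the +1 mirrors the diff's own comment about avoiding a session of indefinite length:

package main

import (
    "fmt"
    "os"

    "github.com/gorilla/sessions"
)

func newStore(secret []byte, domain string, tokenValiditySeconds int) sessions.Store {
    // Session data lands in files under os.TempDir(); only the session
    // ID travels in the cookie.
    fs := sessions.NewFilesystemStore(os.TempDir(), secret)
    fs.Options.Domain = domain
    fs.Options.MaxAge = tokenValiditySeconds + 1
    return fs
}

func main() {
    s := newStore([]byte("not-a-real-secret"), "example.com", 3600)
    fmt.Printf("%T\n", s)
}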
Some files were not shown because too many files have changed in this diff.