diff --git a/authentik/tasks/management/commands/worker.py b/authentik/tasks/management/commands/worker.py index e45faa56e1..71441db7eb 100644 --- a/authentik/tasks/management/commands/worker.py +++ b/authentik/tasks/management/commands/worker.py @@ -11,6 +11,13 @@ class Command(BaseCommand): """Run worker""" def add_arguments(self, parser): + parser.add_argument( + "--pid-file", + action="store", + default=None, + dest="pid_file", + help="PID file", + ) parser.add_argument( "--reload", action="store_true", @@ -44,14 +51,26 @@ class Command(BaseCommand): ) def handle( - self, use_watcher, use_polling_watcher, use_gevent, processes, threads, verbosity, **options + self, + pid_file, + use_watcher, + use_polling_watcher, + use_gevent, + processes, + threads, + verbosity, + **options, ): executable_name = "dramatiq-gevent" if use_gevent else "dramatiq" executable_path = self._resolve_executable(executable_name) - watch_args = ["--watch", "."] if use_watcher else [] + watch_args = ["--watch", "authentik"] if use_watcher else [] if watch_args and use_polling_watcher: watch_args.append("--watch-use-polling") + pid_file_args = [] + if pid_file is not None: + pid_file_args = ["--pid-file", pid_file] + verbosity_args = ["-v"] * (verbosity - 1) tasks_modules = self._discover_tasks_modules() @@ -64,6 +83,7 @@ class Command(BaseCommand): "--threads", str(threads), *watch_args, + *pid_file_args, *verbosity_args, *tasks_modules, ] diff --git a/cmd/server/server.go b/cmd/server/server.go index 4091c20957..c14f2c49d6 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/url" + "syscall" "time" "github.com/getsentry/sentry-go" @@ -19,6 +20,7 @@ import ( sentryutils "goauthentik.io/internal/utils/sentry" webutils "goauthentik.io/internal/utils/web" "goauthentik.io/internal/web" + "goauthentik.io/internal/worker" ) var rootCmd = &cobra.Command{ @@ -65,6 +67,14 @@ var rootCmd = &cobra.Command{ panic(err) } + worker := worker.New() + if config.Get().Worker.Embedded { + err = worker.Start() + if err != nil { + panic(err) + } + } + ws := web.NewWebServer() ws.Core().AddHealthyCallback(func() { if config.Get().Outposts.DisableEmbeddedOutpost { @@ -76,6 +86,7 @@ var rootCmd = &cobra.Command{ <-ex l.Info("shutting down webserver") go ws.Shutdown() + go worker.Kill(syscall.SIGTERM) }, } diff --git a/internal/config/struct.go b/internal/config/struct.go index 05e544d844..8eff328de3 100644 --- a/internal/config/struct.go +++ b/internal/config/struct.go @@ -80,5 +80,5 @@ type WebConfig struct { } type WorkerConfig struct { - Embedded string `yaml:"embedded" env:"EMBEDDED, overwrite"` + Embedded bool `yaml:"embedded" env:"EMBEDDED, overwrite"` } diff --git a/internal/web/web.go b/internal/web/web.go index af49d8b2d5..fcd16543e2 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -27,7 +27,6 @@ import ( "goauthentik.io/internal/utils" "goauthentik.io/internal/utils/web" "goauthentik.io/internal/web/brand_tls" - "goauthentik.io/internal/worker" ) const ( @@ -47,7 +46,6 @@ type WebServer struct { g *gounicorn.GoUnicorn gunicornReady bool - worker *worker.Worker mainRouter *mux.Router loggingRouter *mux.Router log *log.Entry @@ -172,7 +170,6 @@ func (ws *WebServer) Start() { go ws.runMetricsServer() go ws.attemptStartBackend() - go ws.attemptStartWorker() go ws.listenPlain() go ws.listenTLS() } @@ -202,12 +199,6 @@ func (ws *WebServer) attemptStartBackend() { } } -func (ws *WebServer) attemptStartWorker() { - if ws.worker == nil { - return - } -} - func (ws *WebServer) Core() *gounicorn.GoUnicorn { return ws.g } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 6bc92b0c64..b1ff348614 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -6,15 +6,11 @@ import ( "os/exec" "os/signal" "runtime" - "strconv" - "strings" "syscall" - "time" log "github.com/sirupsen/logrus" "goauthentik.io/internal/config" - "goauthentik.io/internal/utils" ) type Worker struct { @@ -26,30 +22,30 @@ type Worker struct { pidFile string started bool killed bool - alive bool } -func New(healthcheck func() bool) *Worker { +func New() *Worker { logger := log.WithField("logger", "authentik.router.worker") w := &Worker{ - Healthcheck: healthcheck, - log: logger, - started: false, - killed: false, - alive: false, - HealthyCallback: func() {}, + log: logger, + started: false, + killed: false, } w.initCmd() c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGHUP, syscall.SIGUSR2) + signal.Notify(c, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM) go func() { for sig := range c { - if sig == syscall.SIGHUP { - w.log.Info("SIGHUP received, forwarding to gunicorn") + switch sig { + case syscall.SIGHUP: + w.log.Info("SIGHUP received, forwarding to dramatiq") w.Reload() - } else if sig == syscall.SIGUSR2 { - w.log.Info("SIGUSR2 received, restarting gunicorn") - w.Restart() + case syscall.SIGINT: + w.log.Info("SIGINT received, stopping dramatiq") + w.Kill(syscall.SIGINT) + case syscall.SIGTERM: + w.log.Info("SIGTERM received, stopping dramatiq") + w.Kill(syscall.SIGTERM) } } }() @@ -58,135 +54,56 @@ func New(healthcheck func() bool) *Worker { func (w *Worker) initCmd() { command := "./manage.py" - args := []string{"dev_server"} - if !config.Get().Debug { - pidFile, err := os.CreateTemp("", "authentik-gunicorn.*.pid") - if err != nil { - panic(fmt.Errorf("failed to create temporary pid file: %v", err)) - } - w.pidFile = pidFile.Name() - command = "gunicorn" - args = []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"} - if w.pidFile != "" { - args = append(args, "--pid", w.pidFile) - } + args := []string{"worker"} + if config.Get().Debug { + args = append(args, "--reload") } - w.log.WithField("args", args).WithField("cmd", command).Debug("Starting gunicorn") + + pidFile, err := os.CreateTemp("", "authentik-dramatiq.pid") + if err != nil { + panic(fmt.Errorf("failed to create temporary pid file: %v", err)) + } + w.pidFile = pidFile.Name() + args = append(args, "--pid-file", w.pidFile) + + w.log.WithField("args", args).WithField("cmd", command).Debug("Starting dramatiq") w.p = exec.Command(command, args...) w.p.Env = os.Environ() w.p.Stdout = os.Stdout w.p.Stderr = os.Stderr } -func (w *Worker) IsRunning() bool { - return w.alive -} - func (w *Worker) Start() error { - if w.started { + if !w.started { w.initCmd() } w.killed = false w.started = true - go w.healthcheck() return w.p.Run() } -func (w *Worker) healthcheck() { - w.log.Debug("starting healthcheck") - // Default healthcheck is every 1 second on startup - // once we've been healthy once, increase to 30 seconds - for range time.NewTicker(time.Second).C { - if w.Healthcheck() { - w.alive = true - w.log.Debug("backend is alive, backing off with healthchecks") - w.HealthyCallback() - break - } - w.log.Debug("backend not alive yet") - } -} - func (w *Worker) Reload() { - w.log.WithField("method", "reload").Info("reloading gunicorn") + w.log.WithField("method", "reload").Info("reloading dramatiq") err := w.p.Process.Signal(syscall.SIGHUP) if err != nil { - w.log.WithError(err).Warning("failed to reload gunicorn") + w.log.WithError(err).Warning("failed to reload dramatiq") } } -func (w *Worker) Restart() { - w.log.WithField("method", "restart").Info("restart gunicorn") - if w.pidFile == "" { - w.log.Warning("pidfile is non existent, cannot restart") - return - } - - err := w.p.Process.Signal(syscall.SIGUSR2) - if err != nil { - w.log.WithError(err).Warning("failed to restart gunicorn") - return - } - - newPidFile := fmt.Sprintf("%s.2", w.pidFile) - - // Wait for the new PID file to be created - for range time.NewTicker(1 * time.Second).C { - _, err = os.Stat(newPidFile) - if err == nil || !os.IsNotExist(err) { - break - } - w.log.Debugf("waiting for new gunicorn pidfile to appear at %s", newPidFile) - } - if err != nil { - w.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") - return - } - - newPidB, err := os.ReadFile(newPidFile) - if err != nil { - w.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") - return - } - newPidS := strings.TrimSpace(string(newPidB[:])) - newPid, err := strconv.Atoi(newPidS) - if err != nil { - w.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") - return - } - w.log.Warningf("new gunicorn PID is %d", newPid) - - newProcess, err := utils.FindProcess(newPid) - if newProcess == nil || err != nil { - w.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") - return - } - - // The new process has started, let's gracefully kill the old one - w.log.Warning("killing old gunicorn") - err = w.p.Process.Signal(syscall.SIGTERM) - if err != nil { - w.log.Warning("failed to kill old instance of gunicorn") - } - - w.p.Process = newProcess - // No need to close any files and the .2 pid file is deleted by Gunicorn -} - -func (w *Worker) Kill() { +func (w *Worker) Kill(sig syscall.Signal) { if !w.started { return } var err error if runtime.GOOS == "darwin" { - w.log.WithField("method", "kill").Warning("stopping gunicorn") + w.log.WithField("method", "processKill").Warning("stopping dramatiq") err = w.p.Process.Kill() } else { - w.log.WithField("method", "sigterm").Warning("stopping gunicorn") - err = syscall.Kill(w.p.Process.Pid, syscall.SIGTERM) + w.log.WithField("method", "syscallKill").Warning("stopping dramatiq") + err = syscall.Kill(w.p.Process.Pid, sig) } if err != nil { - w.log.WithError(err).Warning("failed to stop gunicorn") + w.log.WithError(err).Warning("failed to stop dramatiq") } if w.pidFile != "" { err := os.Remove(w.pidFile)