diff --git a/cmd/server/server.go b/cmd/server/server.go index b734a99d4..e81b97c22 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -4,11 +4,15 @@ import ( "fmt" "net/http" "net/url" + "os" + "os/signal" + "syscall" "time" "github.com/getsentry/sentry-go" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "goauthentik.io/internal/common" "goauthentik.io/internal/config" "goauthentik.io/internal/constants" @@ -70,6 +74,21 @@ var rootCmd = &cobra.Command{ l.Info("shutting down gunicorn") g.Kill() }() + + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP, syscall.SIGUSR2) + go func() { + sig := <-c + if sig == syscall.SIGHUP { + log.Info("SIGHUP received, forwarding to gunicorn") + g.Reload() + } + if sig == syscall.SIGUSR2 { + log.Info("SIGUSR2 received, restarting gunicorn") + g.Restart() + } + }() + ws := web.NewWebServer(g) g.HealthyCallback = func() { if !config.Get().Outposts.DisableEmbeddedOutpost { @@ -92,8 +111,24 @@ func attemptStartBackend(g *gounicorn.GoUnicorn) { if !running { return } + g.Kill() + log.WithField("logger", "authentik.router").Info("starting gunicorn") err := g.Start() - log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting") + if err != nil { + log.WithField("logger", "authentik.router").WithError(err).Error("gunicorn failed to start, restarting") + continue + } + failedChecks := 0 + for range time.Tick(30 * time.Second) { + if !g.IsRunning() { + log.WithField("logger", "authentik.router").Warningf("gunicorn process failed healthcheck %d times", failedChecks) + failedChecks += 1 + } + if failedChecks >= 3 { + log.WithField("logger", "authentik.router").WithError(err).Error("gunicorn process failed healthcheck three times, restarting") + break + } + } } } diff --git a/internal/gounicorn/gounicorn.go b/internal/gounicorn/gounicorn.go index 5cf65a733..534e5d40e 100644 --- a/internal/gounicorn/gounicorn.go +++ b/internal/gounicorn/gounicorn.go @@ -1,15 +1,21 @@ package gounicorn import ( + "fmt" + "io/ioutil" "net/http" "os" "os/exec" "runtime" + "strconv" + "strings" "syscall" "time" log "github.com/sirupsen/logrus" + "goauthentik.io/internal/config" + "goauthentik.io/internal/utils" "goauthentik.io/internal/utils/web" ) @@ -18,6 +24,7 @@ type GoUnicorn struct { log *log.Entry p *exec.Cmd + pidFile *string started bool killed bool alive bool @@ -27,6 +34,7 @@ func New() *GoUnicorn { logger := log.WithField("logger", "authentik.router.unicorn") g := &GoUnicorn{ log: logger, + pidFile: nil, started: false, killed: false, alive: false, @@ -37,8 +45,13 @@ func New() *GoUnicorn { } func (g *GoUnicorn) initCmd() { + pidFile, _ := os.CreateTemp("", "authentik-gunicorn.*.pid") + g.pidFile = func() *string { s := pidFile.Name(); return &s }() command := "gunicorn" args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"} + if g.pidFile != nil { + args = append(args, "--pid", *g.pidFile) + } if config.Get().Debug { command = "./manage.py" args = []string{"runserver"} @@ -55,16 +68,13 @@ func (g *GoUnicorn) IsRunning() bool { } func (g *GoUnicorn) Start() error { - if g.killed { - g.log.Debug("Not restarting gunicorn since we're shutdown") - return nil - } if g.started { g.initCmd() } + g.killed = false g.started = true go g.healthcheck() - return g.p.Run() + return g.p.Start() } func (g *GoUnicorn) healthcheck() { @@ -96,8 +106,77 @@ func (g *GoUnicorn) healthcheck() { } } +func (g *GoUnicorn) Reload() { + g.log.WithField("method", "reload").Info("reloading gunicorn") + err := g.p.Process.Signal(syscall.SIGHUP) + if err != nil { + g.log.WithError(err).Warning("failed to reload gunicorn") + } +} + +func (g *GoUnicorn) Restart() { + g.log.WithField("method", "restart").Info("restart gunicorn") + if g.pidFile == nil { + g.log.Warning("pidfile is non existent, cannot restart") + return + } + + err := g.p.Process.Signal(syscall.SIGUSR2) + if err != nil { + g.log.WithError(err).Warning("failed to restart gunicorn") + return + } + + newPidFile := fmt.Sprintf("%s.2", *g.pidFile) + + // Wait for the new PID file to be created + for range time.Tick(1 * time.Second) { + _, err = os.Stat(newPidFile) + if err == nil || !os.IsNotExist(err) { + break + } + g.log.Debugf("waiting for new gunicorn pidfile to appear at %s", newPidFile) + } + if err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + + newPidB, err := ioutil.ReadFile(newPidFile) + if err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + newPidS := strings.TrimSpace(string(newPidB[:])) + newPid, err := strconv.Atoi(newPidS) + if err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + g.log.Warningf("new gunicorn PID is %d", newPid) + + newProcess, err := utils.FindProcess(newPid) + if newProcess == nil || err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + + // The new process has started, let's gracefully kill the old one + g.log.Warningf("killing old gunicorn") + err = g.p.Process.Signal(syscall.SIGTERM) + if err != nil { + g.log.Warning("failed to kill old instance of gunicorn") + } + + g.p.Process = newProcess + + // No need to close any files and the .2 pid file is deleted by Gunicorn +} + func (g *GoUnicorn) Kill() { - g.killed = true + if !g.started { + return + } var err error if runtime.GOOS == "darwin" { g.log.WithField("method", "kill").Warning("stopping gunicorn") @@ -109,4 +188,8 @@ func (g *GoUnicorn) Kill() { if err != nil { g.log.WithError(err).Warning("failed to stop gunicorn") } + if g.pidFile != nil { + os.Remove(*g.pidFile) + } + g.killed = true } diff --git a/internal/utils/process.go b/internal/utils/process.go new file mode 100644 index 000000000..c53b4fc0e --- /dev/null +++ b/internal/utils/process.go @@ -0,0 +1,38 @@ +package utils + +import ( + "fmt" + "os" + "syscall" +) + +func FindProcess(pid int) (*os.Process, error) { + if pid <= 0 { + return nil, fmt.Errorf("invalid pid %v", pid) + } + // The error doesn't mean anything on Unix systems, let's just check manually + // that the new gunicorn master has properly started + // https://github.com/golang/go/issues/34396 + proc, err := os.FindProcess(int(pid)) + if err != nil { + return nil, err + } + err = proc.Signal(syscall.Signal(0)) + if err == nil { + return proc, nil + } + if err.Error() == "os: process already finished" { + return nil, nil + } + errno, ok := err.(syscall.Errno) + if !ok { + return nil, err + } + switch errno { + case syscall.ESRCH: + return nil, nil + case syscall.EPERM: + return proc, nil + } + return nil, err +}