providers/ldap: add unbind flow execution (#4484)

add unbind flow execution

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-01-23 20:36:30 +01:00 committed by GitHub
parent b2d272bf6f
commit a9b32e2f97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 267 additions and 132 deletions

View File

@ -1,8 +1,11 @@
package ak package ak
import "context"
type Outpost interface { type Outpost interface {
Start() error Start() error
Stop() error
Refresh() error Refresh() error
TimerFlowCacheExpiry() TimerFlowCacheExpiry(context.Context)
Type() string Type() string
} }

View File

@ -1,15 +1,16 @@
package ak package ak
import ( import (
"context"
"time" "time"
) )
func (a *APIController) startPeriodicalTasks() { func (a *APIController) startPeriodicalTasks() {
go a.Server.TimerFlowCacheExpiry() ctx, canc := context.WithCancel(context.Background())
go func() { defer canc()
for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) { go a.Server.TimerFlowCacheExpiry(ctx)
a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks") for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) {
a.Server.TimerFlowCacheExpiry() a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks")
} a.Server.TimerFlowCacheExpiry(ctx)
}() }
} }

View File

@ -143,6 +143,10 @@ func (fe *FlowExecutor) GetSession() *http.Cookie {
return fe.session return fe.session
} }
func (fe *FlowExecutor) SetSession(s *http.Cookie) {
fe.session = s
}
// WarmUp Ensure authentik's flow cache is warmed up // WarmUp Ensure authentik's flow cache is warmed up
func (fe *FlowExecutor) WarmUp() error { func (fe *FlowExecutor) WarmUp() error {
gcsp := sentry.StartSpan(fe.Context, "authentik.outposts.flow_executor.get_challenge") gcsp := sentry.StartSpan(fe.Context, "authentik.outposts.flow_executor.get_challenge")

View File

@ -1,9 +1,14 @@
package bind package bind
import "github.com/nmcclain/ldap" import (
"context"
"github.com/nmcclain/ldap"
)
type Binder interface { type Binder interface {
GetUsername(string) (string, error) GetUsername(string) (string, error)
Bind(username string, req *Request) (ldap.LDAPResultCode, error) Bind(username string, req *Request) (ldap.LDAPResultCode, error)
TimerFlowCacheExpiry() Unbind(username string, req *Request) (ldap.LDAPResultCode, error)
TimerFlowCacheExpiry(context.Context)
} }

View File

@ -0,0 +1,98 @@
package direct
import (
"context"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/metrics"
)
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
fe := flow.NewFlowExecutor(req.Context(), db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"bindDN": req.BindDN,
"client": req.RemoteAddr(),
"requestId": req.ID(),
})
fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Answers[flow.StageIdentification] = username
fe.Answers[flow.StagePassword] = req.BindPW
passed, err := fe.Execute()
flags := flags.UserFlags{
Session: fe.GetSession(),
}
db.si.SetFlags(req.BindDN, &flags)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "flow_error",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to execute flow")
return ldap.LDAPResultInvalidCredentials, nil
}
if !passed {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "invalid_credentials",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().Info("Invalid credentials")
return ldap.LDAPResultInvalidCredentials, nil
}
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
if !access {
req.Log().Info("Access denied for user")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_denied",
"app": db.si.GetAppSlug(),
}).Inc()
return ldap.LDAPResultInsufficientAccessRights, nil
}
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_check_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to check access")
return ldap.LDAPResultOperationsError, nil
}
req.Log().Info("User has access")
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
// Get user info to store in context
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "user_info_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to get user info")
return ldap.LDAPResultOperationsError, nil
}
cs := db.SearchAccessCheck(userInfo.User)
flags.UserPk = userInfo.User.Pk
flags.CanSearch = cs != nil
db.si.SetFlags(req.BindDN, &flags)
if flags.CanSearch {
req.Log().WithField("group", cs).Info("Allowed access to search")
}
uisp.Finish()
return ldap.LDAPResultSuccess, nil
}

View File

@ -5,16 +5,10 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/getsentry/sentry-go"
goldap "github.com/go-ldap/ldap/v3" goldap "github.com/go-ldap/ldap/v3"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/flow" "goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils" "goauthentik.io/internal/outpost/ldap/utils"
) )
@ -53,90 +47,6 @@ func (db *DirectBinder) GetUsername(dn string) (string, error) {
return "", errors.New("failed to find cn") return "", errors.New("failed to find cn")
} }
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
fe := flow.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"bindDN": req.BindDN,
"client": req.RemoteAddr(),
"requestId": req.ID(),
})
fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Answers[flow.StageIdentification] = username
fe.Answers[flow.StagePassword] = req.BindPW
passed, err := fe.Execute()
flags := flags.UserFlags{
Session: fe.GetSession(),
}
db.si.SetFlags(req.BindDN, flags)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "flow_error",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to execute flow")
return ldap.LDAPResultInvalidCredentials, nil
}
if !passed {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "invalid_credentials",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().Info("Invalid credentials")
return ldap.LDAPResultInvalidCredentials, nil
}
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
if !access {
req.Log().Info("Access denied for user")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_denied",
"app": db.si.GetAppSlug(),
}).Inc()
return ldap.LDAPResultInsufficientAccessRights, nil
}
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_check_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to check access")
return ldap.LDAPResultOperationsError, nil
}
req.Log().Info("User has access")
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
// Get user info to store in context
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "user_info_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to get user info")
return ldap.LDAPResultOperationsError, nil
}
cs := db.SearchAccessCheck(userInfo.User)
flags.UserPk = userInfo.User.Pk
flags.CanSearch = cs != nil
db.si.SetFlags(req.BindDN, flags)
if flags.CanSearch {
req.Log().WithField("group", cs).Info("Allowed access to search")
}
uisp.Finish()
return ldap.LDAPResultSuccess, nil
}
// SearchAccessCheck Check if the current user is allowed to search // SearchAccessCheck Check if the current user is allowed to search
func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string { func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
for _, group := range user.Groups { for _, group := range user.Groups {
@ -153,8 +63,8 @@ func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
return nil return nil
} }
func (db *DirectBinder) TimerFlowCacheExpiry() { func (db *DirectBinder) TimerFlowCacheExpiry(ctx context.Context) {
fe := flow.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{}) fe := flow.NewFlowExecutor(ctx, db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
fe.Params.Add("goauthentik.io/outpost/ldap", "true") fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true") fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true")

View File

@ -0,0 +1,29 @@
package direct
import (
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/bind"
)
func (db *DirectBinder) Unbind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
flags := db.si.GetFlags(req.BindDN)
if flags == nil || flags.Session == nil {
return ldap.LDAPResultSuccess, nil
}
fe := flow.NewFlowExecutor(req.Context(), db.si.GetInvalidationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"boundDN": req.BindDN,
"client": req.RemoteAddr(),
"requestId": req.ID(),
})
fe.SetSession(flags.Session)
fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
_, err := fe.Execute()
if err != nil {
db.log.WithError(err).Warning("failed to logout user")
}
db.si.SetFlags(req.BindDN, nil)
return ldap.LDAPResultSuccess, nil
}

View File

@ -27,10 +27,11 @@ type ProviderInstance struct {
searcher search.Searcher searcher search.Searcher
binder bind.Binder binder bind.Binder
appSlug string appSlug string
flowSlug string authenticationFlowSlug string
s *LDAPServer invalidationFlowSlug string
log *log.Entry s *LDAPServer
log *log.Entry
tlsServerName *string tlsServerName *string
cert *tls.Certificate cert *tls.Certificate
@ -79,9 +80,13 @@ func (pi *ProviderInstance) GetFlags(dn string) *flags.UserFlags {
return flags return flags
} }
func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) { func (pi *ProviderInstance) SetFlags(dn string, flag *flags.UserFlags) {
pi.boundUsersMutex.Lock() pi.boundUsersMutex.Lock()
pi.boundUsers[dn] = &flag if flag == nil {
delete(pi.boundUsers, dn)
} else {
pi.boundUsers[dn] = flag
}
pi.boundUsersMutex.Unlock() pi.boundUsersMutex.Unlock()
} }
@ -89,8 +94,12 @@ func (pi *ProviderInstance) GetAppSlug() string {
return pi.appSlug return pi.appSlug
} }
func (pi *ProviderInstance) GetFlowSlug() string { func (pi *ProviderInstance) GetAuthenticationFlowSlug() string {
return pi.flowSlug return pi.authenticationFlowSlug
}
func (pi *ProviderInstance) GetInvalidationFlowSlug() string {
return pi.invalidationFlowSlug
} }
func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID { func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {

View File

@ -1,6 +1,7 @@
package ldap package ldap
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"sync" "sync"
@ -40,6 +41,7 @@ func NewServer(ac *ak.APIController) *LDAPServer {
} }
ls.defaultCert = &defaultCert ls.defaultCert = &defaultCert
s.BindFunc("", ls) s.BindFunc("", ls)
s.UnbindFunc("", ls)
s.SearchFunc("", ls) s.SearchFunc("", ls)
return ls return ls
} }
@ -92,9 +94,13 @@ func (ls *LDAPServer) Start() error {
return nil return nil
} }
func (ls *LDAPServer) TimerFlowCacheExpiry() { func (ls *LDAPServer) Stop() error {
return nil
}
func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) {
for _, p := range ls.providers { for _, p := range ls.providers {
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache") ls.log.WithField("flow", p.authenticationFlowSlug).Debug("Pre-heating flow cache")
p.binder.TimerFlowCacheExpiry() p.binder.TimerFlowCacheExpiry(ctx)
} }
} }

View File

@ -28,6 +28,16 @@ func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
return nil return nil
} }
func (ls *LDAPServer) getInvalidationFlow() string {
req, _, err := ls.ac.Client.CoreApi.CoreTenantsCurrentRetrieve(context.Background()).Execute()
if err != nil {
ls.log.WithError(err).Warning("failed to fetch tenant config")
return ""
}
flow := req.GetFlowInvalidation()
return flow
}
func (ls *LDAPServer) Refresh() error { func (ls *LDAPServer) Refresh() error {
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute() outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
if err != nil { if err != nil {
@ -37,6 +47,7 @@ func (ls *LDAPServer) Refresh() error {
return errors.New("no ldap provider defined") return errors.New("no ldap provider defined")
} }
providers := make([]*ProviderInstance, len(outposts.Results)) providers := make([]*ProviderInstance, len(outposts.Results))
invalidationFlow := ls.getInvalidationFlow()
for idx, provider := range outposts.Results { for idx, provider := range outposts.Results {
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn)) userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
@ -53,22 +64,23 @@ func (ls *LDAPServer) Refresh() error {
} }
providers[idx] = &ProviderInstance{ providers[idx] = &ProviderInstance{
BaseDN: *provider.BaseDn, BaseDN: *provider.BaseDn,
VirtualGroupDN: virtualGroupDN, VirtualGroupDN: virtualGroupDN,
GroupDN: groupDN, GroupDN: groupDN,
UserDN: userDN, UserDN: userDN,
appSlug: provider.ApplicationSlug, appSlug: provider.ApplicationSlug,
flowSlug: provider.BindFlowSlug, authenticationFlowSlug: provider.BindFlowSlug,
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, invalidationFlowSlug: invalidationFlow,
boundUsersMutex: sync.RWMutex{}, searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
boundUsers: users, boundUsersMutex: sync.RWMutex{},
s: ls, boundUsers: users,
log: logger, s: ls,
tlsServerName: provider.TlsServerName, log: logger,
uidStartNumber: *provider.UidStartNumber, tlsServerName: provider.TlsServerName,
gidStartNumber: *provider.GidStartNumber, uidStartNumber: *provider.UidStartNumber,
outpostName: ls.ac.Outpost.Name, gidStartNumber: *provider.GidStartNumber,
outpostPk: provider.Pk, outpostName: ls.ac.Outpost.Name,
outpostPk: provider.Pk,
} }
if kp := provider.Certificate.Get(); kp != nil { if kp := provider.Certificate.Get(); kp != nil {
err := ls.cs.AddKeypair(*kp) err := ls.cs.AddKeypair(*kp)

View File

@ -11,7 +11,8 @@ type LDAPServerInstance interface {
GetAPIClient() *api.APIClient GetAPIClient() *api.APIClient
GetOutpostName() string GetOutpostName() string
GetFlowSlug() string GetAuthenticationFlowSlug() string
GetInvalidationFlowSlug() string
GetAppSlug() string GetAppSlug() string
GetSearchAllowedGroups() []*strfmt.UUID GetSearchAllowedGroups() []*strfmt.UUID
@ -32,7 +33,7 @@ type LDAPServerInstance interface {
UsersForGroup(api.Group) []string UsersForGroup(api.Group) []string
GetFlags(dn string) *flags.UserFlags GetFlags(dn string) *flags.UserFlags
SetFlags(dn string, flags flags.UserFlags) SetFlags(dn string, flags *flags.UserFlags)
GetBaseEntry() *ldap.Entry GetBaseEntry() *ldap.Entry
GetNeededObjects(int, string, string) (bool, bool) GetNeededObjects(int, string, string) (bool, bool)

View File

@ -0,0 +1,53 @@
package ldap
import (
"net"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/metrics"
)
func (ls *LDAPServer) Unbind(boundDN string, conn net.Conn) (ldap.LDAPResultCode, error) {
req, span := bind.NewRequest(boundDN, "", conn)
selectedApp := ""
defer func() {
span.Finish()
metrics.Requests.With(prometheus.Labels{
"outpost_name": ls.ac.Outpost.Name,
"type": "unbind",
"app": selectedApp,
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Unbind request")
}()
defer func() {
err := recover()
if err == nil {
return
}
log.WithError(err.(error)).Error("recover in bind request")
sentry.CaptureException(err.(error))
}()
for _, instance := range ls.providers {
username, err := instance.binder.GetUsername(boundDN)
if err == nil {
selectedApp = instance.GetAppSlug()
return instance.binder.Unbind(username, req)
} else {
req.Log().WithError(err).Debug("Username not for instance")
}
}
req.Log().WithField("request", "unbind").Warning("No provider found for request")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ls.ac.Outpost.Name,
"type": "unbind",
"reason": "no_provider",
"app": "",
}).Inc()
return ldap.LDAPResultOperationsError, nil
}

View File

@ -81,7 +81,7 @@ func (ps *ProxyServer) Type() string {
return "proxy" return "proxy"
} }
func (ps *ProxyServer) TimerFlowCacheExpiry() {} func (ps *ProxyServer) TimerFlowCacheExpiry(context.Context) {}
func (ps *ProxyServer) GetCertificate(serverName string) *tls.Certificate { func (ps *ProxyServer) GetCertificate(serverName string) *tls.Certificate {
app, ok := ps.apps[serverName] app, ok := ps.apps[serverName]
@ -163,6 +163,10 @@ func (ps *ProxyServer) Start() error {
return nil return nil
} }
func (ps *ProxyServer) Stop() error {
return nil
}
func (ps *ProxyServer) serve(listener net.Listener) { func (ps *ProxyServer) serve(listener net.Listener) {
srv := &http.Server{Handler: ps.mux} srv := &http.Server{Handler: ps.mux}