providers/ldap: add unbind flow execution (#4484)
add unbind flow execution Signed-off-by: Jens Langhammer <jens@goauthentik.io> Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
b2d272bf6f
commit
a9b32e2f97
|
@ -1,8 +1,11 @@
|
||||||
package ak
|
package ak
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
type Outpost interface {
|
type Outpost interface {
|
||||||
Start() error
|
Start() error
|
||||||
|
Stop() error
|
||||||
Refresh() error
|
Refresh() error
|
||||||
TimerFlowCacheExpiry()
|
TimerFlowCacheExpiry(context.Context)
|
||||||
Type() string
|
Type() string
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
package ak
|
package ak
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *APIController) startPeriodicalTasks() {
|
func (a *APIController) startPeriodicalTasks() {
|
||||||
go a.Server.TimerFlowCacheExpiry()
|
ctx, canc := context.WithCancel(context.Background())
|
||||||
go func() {
|
defer canc()
|
||||||
for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) {
|
go a.Server.TimerFlowCacheExpiry(ctx)
|
||||||
a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks")
|
for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) {
|
||||||
a.Server.TimerFlowCacheExpiry()
|
a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks")
|
||||||
}
|
a.Server.TimerFlowCacheExpiry(ctx)
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -143,6 +143,10 @@ func (fe *FlowExecutor) GetSession() *http.Cookie {
|
||||||
return fe.session
|
return fe.session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fe *FlowExecutor) SetSession(s *http.Cookie) {
|
||||||
|
fe.session = s
|
||||||
|
}
|
||||||
|
|
||||||
// WarmUp Ensure authentik's flow cache is warmed up
|
// WarmUp Ensure authentik's flow cache is warmed up
|
||||||
func (fe *FlowExecutor) WarmUp() error {
|
func (fe *FlowExecutor) WarmUp() error {
|
||||||
gcsp := sentry.StartSpan(fe.Context, "authentik.outposts.flow_executor.get_challenge")
|
gcsp := sentry.StartSpan(fe.Context, "authentik.outposts.flow_executor.get_challenge")
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
package bind
|
package bind
|
||||||
|
|
||||||
import "github.com/nmcclain/ldap"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/nmcclain/ldap"
|
||||||
|
)
|
||||||
|
|
||||||
type Binder interface {
|
type Binder interface {
|
||||||
GetUsername(string) (string, error)
|
GetUsername(string) (string, error)
|
||||||
Bind(username string, req *Request) (ldap.LDAPResultCode, error)
|
Bind(username string, req *Request) (ldap.LDAPResultCode, error)
|
||||||
TimerFlowCacheExpiry()
|
Unbind(username string, req *Request) (ldap.LDAPResultCode, error)
|
||||||
|
TimerFlowCacheExpiry(context.Context)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
package direct
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/nmcclain/ldap"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/internal/outpost/flow"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/bind"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/flags"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
|
||||||
|
fe := flow.NewFlowExecutor(req.Context(), db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
|
||||||
|
"bindDN": req.BindDN,
|
||||||
|
"client": req.RemoteAddr(),
|
||||||
|
"requestId": req.ID(),
|
||||||
|
})
|
||||||
|
fe.DelegateClientIP(req.RemoteAddr())
|
||||||
|
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
|
||||||
|
|
||||||
|
fe.Answers[flow.StageIdentification] = username
|
||||||
|
fe.Answers[flow.StagePassword] = req.BindPW
|
||||||
|
|
||||||
|
passed, err := fe.Execute()
|
||||||
|
flags := flags.UserFlags{
|
||||||
|
Session: fe.GetSession(),
|
||||||
|
}
|
||||||
|
db.si.SetFlags(req.BindDN, &flags)
|
||||||
|
if err != nil {
|
||||||
|
metrics.RequestsRejected.With(prometheus.Labels{
|
||||||
|
"outpost_name": db.si.GetOutpostName(),
|
||||||
|
"type": "bind",
|
||||||
|
"reason": "flow_error",
|
||||||
|
"app": db.si.GetAppSlug(),
|
||||||
|
}).Inc()
|
||||||
|
req.Log().WithError(err).Warning("failed to execute flow")
|
||||||
|
return ldap.LDAPResultInvalidCredentials, nil
|
||||||
|
}
|
||||||
|
if !passed {
|
||||||
|
metrics.RequestsRejected.With(prometheus.Labels{
|
||||||
|
"outpost_name": db.si.GetOutpostName(),
|
||||||
|
"type": "bind",
|
||||||
|
"reason": "invalid_credentials",
|
||||||
|
"app": db.si.GetAppSlug(),
|
||||||
|
}).Inc()
|
||||||
|
req.Log().Info("Invalid credentials")
|
||||||
|
return ldap.LDAPResultInvalidCredentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
|
||||||
|
if !access {
|
||||||
|
req.Log().Info("Access denied for user")
|
||||||
|
metrics.RequestsRejected.With(prometheus.Labels{
|
||||||
|
"outpost_name": db.si.GetOutpostName(),
|
||||||
|
"type": "bind",
|
||||||
|
"reason": "access_denied",
|
||||||
|
"app": db.si.GetAppSlug(),
|
||||||
|
}).Inc()
|
||||||
|
return ldap.LDAPResultInsufficientAccessRights, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
metrics.RequestsRejected.With(prometheus.Labels{
|
||||||
|
"outpost_name": db.si.GetOutpostName(),
|
||||||
|
"type": "bind",
|
||||||
|
"reason": "access_check_fail",
|
||||||
|
"app": db.si.GetAppSlug(),
|
||||||
|
}).Inc()
|
||||||
|
req.Log().WithError(err).Warning("failed to check access")
|
||||||
|
return ldap.LDAPResultOperationsError, nil
|
||||||
|
}
|
||||||
|
req.Log().Info("User has access")
|
||||||
|
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
|
||||||
|
// Get user info to store in context
|
||||||
|
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
|
||||||
|
if err != nil {
|
||||||
|
metrics.RequestsRejected.With(prometheus.Labels{
|
||||||
|
"outpost_name": db.si.GetOutpostName(),
|
||||||
|
"type": "bind",
|
||||||
|
"reason": "user_info_fail",
|
||||||
|
"app": db.si.GetAppSlug(),
|
||||||
|
}).Inc()
|
||||||
|
req.Log().WithError(err).Warning("failed to get user info")
|
||||||
|
return ldap.LDAPResultOperationsError, nil
|
||||||
|
}
|
||||||
|
cs := db.SearchAccessCheck(userInfo.User)
|
||||||
|
flags.UserPk = userInfo.User.Pk
|
||||||
|
flags.CanSearch = cs != nil
|
||||||
|
db.si.SetFlags(req.BindDN, &flags)
|
||||||
|
if flags.CanSearch {
|
||||||
|
req.Log().WithField("group", cs).Info("Allowed access to search")
|
||||||
|
}
|
||||||
|
uisp.Finish()
|
||||||
|
return ldap.LDAPResultSuccess, nil
|
||||||
|
}
|
|
@ -5,16 +5,10 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
|
||||||
goldap "github.com/go-ldap/ldap/v3"
|
goldap "github.com/go-ldap/ldap/v3"
|
||||||
"github.com/nmcclain/ldap"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/api/v3"
|
"goauthentik.io/api/v3"
|
||||||
"goauthentik.io/internal/outpost/flow"
|
"goauthentik.io/internal/outpost/flow"
|
||||||
"goauthentik.io/internal/outpost/ldap/bind"
|
|
||||||
"goauthentik.io/internal/outpost/ldap/flags"
|
|
||||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
|
||||||
"goauthentik.io/internal/outpost/ldap/server"
|
"goauthentik.io/internal/outpost/ldap/server"
|
||||||
"goauthentik.io/internal/outpost/ldap/utils"
|
"goauthentik.io/internal/outpost/ldap/utils"
|
||||||
)
|
)
|
||||||
|
@ -53,90 +47,6 @@ func (db *DirectBinder) GetUsername(dn string) (string, error) {
|
||||||
return "", errors.New("failed to find cn")
|
return "", errors.New("failed to find cn")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
|
|
||||||
fe := flow.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
|
|
||||||
"bindDN": req.BindDN,
|
|
||||||
"client": req.RemoteAddr(),
|
|
||||||
"requestId": req.ID(),
|
|
||||||
})
|
|
||||||
fe.DelegateClientIP(req.RemoteAddr())
|
|
||||||
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
|
|
||||||
|
|
||||||
fe.Answers[flow.StageIdentification] = username
|
|
||||||
fe.Answers[flow.StagePassword] = req.BindPW
|
|
||||||
|
|
||||||
passed, err := fe.Execute()
|
|
||||||
flags := flags.UserFlags{
|
|
||||||
Session: fe.GetSession(),
|
|
||||||
}
|
|
||||||
db.si.SetFlags(req.BindDN, flags)
|
|
||||||
if err != nil {
|
|
||||||
metrics.RequestsRejected.With(prometheus.Labels{
|
|
||||||
"outpost_name": db.si.GetOutpostName(),
|
|
||||||
"type": "bind",
|
|
||||||
"reason": "flow_error",
|
|
||||||
"app": db.si.GetAppSlug(),
|
|
||||||
}).Inc()
|
|
||||||
req.Log().WithError(err).Warning("failed to execute flow")
|
|
||||||
return ldap.LDAPResultInvalidCredentials, nil
|
|
||||||
}
|
|
||||||
if !passed {
|
|
||||||
metrics.RequestsRejected.With(prometheus.Labels{
|
|
||||||
"outpost_name": db.si.GetOutpostName(),
|
|
||||||
"type": "bind",
|
|
||||||
"reason": "invalid_credentials",
|
|
||||||
"app": db.si.GetAppSlug(),
|
|
||||||
}).Inc()
|
|
||||||
req.Log().Info("Invalid credentials")
|
|
||||||
return ldap.LDAPResultInvalidCredentials, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
|
|
||||||
if !access {
|
|
||||||
req.Log().Info("Access denied for user")
|
|
||||||
metrics.RequestsRejected.With(prometheus.Labels{
|
|
||||||
"outpost_name": db.si.GetOutpostName(),
|
|
||||||
"type": "bind",
|
|
||||||
"reason": "access_denied",
|
|
||||||
"app": db.si.GetAppSlug(),
|
|
||||||
}).Inc()
|
|
||||||
return ldap.LDAPResultInsufficientAccessRights, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
metrics.RequestsRejected.With(prometheus.Labels{
|
|
||||||
"outpost_name": db.si.GetOutpostName(),
|
|
||||||
"type": "bind",
|
|
||||||
"reason": "access_check_fail",
|
|
||||||
"app": db.si.GetAppSlug(),
|
|
||||||
}).Inc()
|
|
||||||
req.Log().WithError(err).Warning("failed to check access")
|
|
||||||
return ldap.LDAPResultOperationsError, nil
|
|
||||||
}
|
|
||||||
req.Log().Info("User has access")
|
|
||||||
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
|
|
||||||
// Get user info to store in context
|
|
||||||
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
|
|
||||||
if err != nil {
|
|
||||||
metrics.RequestsRejected.With(prometheus.Labels{
|
|
||||||
"outpost_name": db.si.GetOutpostName(),
|
|
||||||
"type": "bind",
|
|
||||||
"reason": "user_info_fail",
|
|
||||||
"app": db.si.GetAppSlug(),
|
|
||||||
}).Inc()
|
|
||||||
req.Log().WithError(err).Warning("failed to get user info")
|
|
||||||
return ldap.LDAPResultOperationsError, nil
|
|
||||||
}
|
|
||||||
cs := db.SearchAccessCheck(userInfo.User)
|
|
||||||
flags.UserPk = userInfo.User.Pk
|
|
||||||
flags.CanSearch = cs != nil
|
|
||||||
db.si.SetFlags(req.BindDN, flags)
|
|
||||||
if flags.CanSearch {
|
|
||||||
req.Log().WithField("group", cs).Info("Allowed access to search")
|
|
||||||
}
|
|
||||||
uisp.Finish()
|
|
||||||
return ldap.LDAPResultSuccess, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SearchAccessCheck Check if the current user is allowed to search
|
// SearchAccessCheck Check if the current user is allowed to search
|
||||||
func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
|
func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
|
||||||
for _, group := range user.Groups {
|
for _, group := range user.Groups {
|
||||||
|
@ -153,8 +63,8 @@ func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DirectBinder) TimerFlowCacheExpiry() {
|
func (db *DirectBinder) TimerFlowCacheExpiry(ctx context.Context) {
|
||||||
fe := flow.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
|
fe := flow.NewFlowExecutor(ctx, db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
|
||||||
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
|
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
|
||||||
fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true")
|
fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
package direct
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nmcclain/ldap"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/internal/outpost/flow"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/bind"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (db *DirectBinder) Unbind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
|
||||||
|
flags := db.si.GetFlags(req.BindDN)
|
||||||
|
if flags == nil || flags.Session == nil {
|
||||||
|
return ldap.LDAPResultSuccess, nil
|
||||||
|
}
|
||||||
|
fe := flow.NewFlowExecutor(req.Context(), db.si.GetInvalidationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
|
||||||
|
"boundDN": req.BindDN,
|
||||||
|
"client": req.RemoteAddr(),
|
||||||
|
"requestId": req.ID(),
|
||||||
|
})
|
||||||
|
fe.SetSession(flags.Session)
|
||||||
|
fe.DelegateClientIP(req.RemoteAddr())
|
||||||
|
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
|
||||||
|
_, err := fe.Execute()
|
||||||
|
if err != nil {
|
||||||
|
db.log.WithError(err).Warning("failed to logout user")
|
||||||
|
}
|
||||||
|
db.si.SetFlags(req.BindDN, nil)
|
||||||
|
return ldap.LDAPResultSuccess, nil
|
||||||
|
}
|
|
@ -27,10 +27,11 @@ type ProviderInstance struct {
|
||||||
searcher search.Searcher
|
searcher search.Searcher
|
||||||
binder bind.Binder
|
binder bind.Binder
|
||||||
|
|
||||||
appSlug string
|
appSlug string
|
||||||
flowSlug string
|
authenticationFlowSlug string
|
||||||
s *LDAPServer
|
invalidationFlowSlug string
|
||||||
log *log.Entry
|
s *LDAPServer
|
||||||
|
log *log.Entry
|
||||||
|
|
||||||
tlsServerName *string
|
tlsServerName *string
|
||||||
cert *tls.Certificate
|
cert *tls.Certificate
|
||||||
|
@ -79,9 +80,13 @@ func (pi *ProviderInstance) GetFlags(dn string) *flags.UserFlags {
|
||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) {
|
func (pi *ProviderInstance) SetFlags(dn string, flag *flags.UserFlags) {
|
||||||
pi.boundUsersMutex.Lock()
|
pi.boundUsersMutex.Lock()
|
||||||
pi.boundUsers[dn] = &flag
|
if flag == nil {
|
||||||
|
delete(pi.boundUsers, dn)
|
||||||
|
} else {
|
||||||
|
pi.boundUsers[dn] = flag
|
||||||
|
}
|
||||||
pi.boundUsersMutex.Unlock()
|
pi.boundUsersMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,8 +94,12 @@ func (pi *ProviderInstance) GetAppSlug() string {
|
||||||
return pi.appSlug
|
return pi.appSlug
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pi *ProviderInstance) GetFlowSlug() string {
|
func (pi *ProviderInstance) GetAuthenticationFlowSlug() string {
|
||||||
return pi.flowSlug
|
return pi.authenticationFlowSlug
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pi *ProviderInstance) GetInvalidationFlowSlug() string {
|
||||||
|
return pi.invalidationFlowSlug
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {
|
func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package ldap
|
package ldap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -40,6 +41,7 @@ func NewServer(ac *ak.APIController) *LDAPServer {
|
||||||
}
|
}
|
||||||
ls.defaultCert = &defaultCert
|
ls.defaultCert = &defaultCert
|
||||||
s.BindFunc("", ls)
|
s.BindFunc("", ls)
|
||||||
|
s.UnbindFunc("", ls)
|
||||||
s.SearchFunc("", ls)
|
s.SearchFunc("", ls)
|
||||||
return ls
|
return ls
|
||||||
}
|
}
|
||||||
|
@ -92,9 +94,13 @@ func (ls *LDAPServer) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ls *LDAPServer) TimerFlowCacheExpiry() {
|
func (ls *LDAPServer) Stop() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) {
|
||||||
for _, p := range ls.providers {
|
for _, p := range ls.providers {
|
||||||
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
|
ls.log.WithField("flow", p.authenticationFlowSlug).Debug("Pre-heating flow cache")
|
||||||
p.binder.TimerFlowCacheExpiry()
|
p.binder.TimerFlowCacheExpiry(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,16 @@ func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ls *LDAPServer) getInvalidationFlow() string {
|
||||||
|
req, _, err := ls.ac.Client.CoreApi.CoreTenantsCurrentRetrieve(context.Background()).Execute()
|
||||||
|
if err != nil {
|
||||||
|
ls.log.WithError(err).Warning("failed to fetch tenant config")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
flow := req.GetFlowInvalidation()
|
||||||
|
return flow
|
||||||
|
}
|
||||||
|
|
||||||
func (ls *LDAPServer) Refresh() error {
|
func (ls *LDAPServer) Refresh() error {
|
||||||
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
|
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -37,6 +47,7 @@ func (ls *LDAPServer) Refresh() error {
|
||||||
return errors.New("no ldap provider defined")
|
return errors.New("no ldap provider defined")
|
||||||
}
|
}
|
||||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
providers := make([]*ProviderInstance, len(outposts.Results))
|
||||||
|
invalidationFlow := ls.getInvalidationFlow()
|
||||||
for idx, provider := range outposts.Results {
|
for idx, provider := range outposts.Results {
|
||||||
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
|
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
|
||||||
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
|
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
|
||||||
|
@ -53,22 +64,23 @@ func (ls *LDAPServer) Refresh() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
providers[idx] = &ProviderInstance{
|
providers[idx] = &ProviderInstance{
|
||||||
BaseDN: *provider.BaseDn,
|
BaseDN: *provider.BaseDn,
|
||||||
VirtualGroupDN: virtualGroupDN,
|
VirtualGroupDN: virtualGroupDN,
|
||||||
GroupDN: groupDN,
|
GroupDN: groupDN,
|
||||||
UserDN: userDN,
|
UserDN: userDN,
|
||||||
appSlug: provider.ApplicationSlug,
|
appSlug: provider.ApplicationSlug,
|
||||||
flowSlug: provider.BindFlowSlug,
|
authenticationFlowSlug: provider.BindFlowSlug,
|
||||||
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
|
invalidationFlowSlug: invalidationFlow,
|
||||||
boundUsersMutex: sync.RWMutex{},
|
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
|
||||||
boundUsers: users,
|
boundUsersMutex: sync.RWMutex{},
|
||||||
s: ls,
|
boundUsers: users,
|
||||||
log: logger,
|
s: ls,
|
||||||
tlsServerName: provider.TlsServerName,
|
log: logger,
|
||||||
uidStartNumber: *provider.UidStartNumber,
|
tlsServerName: provider.TlsServerName,
|
||||||
gidStartNumber: *provider.GidStartNumber,
|
uidStartNumber: *provider.UidStartNumber,
|
||||||
outpostName: ls.ac.Outpost.Name,
|
gidStartNumber: *provider.GidStartNumber,
|
||||||
outpostPk: provider.Pk,
|
outpostName: ls.ac.Outpost.Name,
|
||||||
|
outpostPk: provider.Pk,
|
||||||
}
|
}
|
||||||
if kp := provider.Certificate.Get(); kp != nil {
|
if kp := provider.Certificate.Get(); kp != nil {
|
||||||
err := ls.cs.AddKeypair(*kp)
|
err := ls.cs.AddKeypair(*kp)
|
||||||
|
|
|
@ -11,7 +11,8 @@ type LDAPServerInstance interface {
|
||||||
GetAPIClient() *api.APIClient
|
GetAPIClient() *api.APIClient
|
||||||
GetOutpostName() string
|
GetOutpostName() string
|
||||||
|
|
||||||
GetFlowSlug() string
|
GetAuthenticationFlowSlug() string
|
||||||
|
GetInvalidationFlowSlug() string
|
||||||
GetAppSlug() string
|
GetAppSlug() string
|
||||||
GetSearchAllowedGroups() []*strfmt.UUID
|
GetSearchAllowedGroups() []*strfmt.UUID
|
||||||
|
|
||||||
|
@ -32,7 +33,7 @@ type LDAPServerInstance interface {
|
||||||
UsersForGroup(api.Group) []string
|
UsersForGroup(api.Group) []string
|
||||||
|
|
||||||
GetFlags(dn string) *flags.UserFlags
|
GetFlags(dn string) *flags.UserFlags
|
||||||
SetFlags(dn string, flags flags.UserFlags)
|
SetFlags(dn string, flags *flags.UserFlags)
|
||||||
|
|
||||||
GetBaseEntry() *ldap.Entry
|
GetBaseEntry() *ldap.Entry
|
||||||
GetNeededObjects(int, string, string) (bool, bool)
|
GetNeededObjects(int, string, string) (bool, bool)
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
package ldap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/nmcclain/ldap"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/bind"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (ls *LDAPServer) Unbind(boundDN string, conn net.Conn) (ldap.LDAPResultCode, error) {
|
||||||
|
req, span := bind.NewRequest(boundDN, "", conn)
|
||||||
|
selectedApp := ""
|
||||||
|
defer func() {
|
||||||
|
span.Finish()
|
||||||
|
metrics.Requests.With(prometheus.Labels{
|
||||||
|
"outpost_name": ls.ac.Outpost.Name,
|
||||||
|
"type": "unbind",
|
||||||
|
"app": selectedApp,
|
||||||
|
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
|
||||||
|
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Unbind request")
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.WithError(err.(error)).Error("recover in bind request")
|
||||||
|
sentry.CaptureException(err.(error))
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, instance := range ls.providers {
|
||||||
|
username, err := instance.binder.GetUsername(boundDN)
|
||||||
|
if err == nil {
|
||||||
|
selectedApp = instance.GetAppSlug()
|
||||||
|
return instance.binder.Unbind(username, req)
|
||||||
|
} else {
|
||||||
|
req.Log().WithError(err).Debug("Username not for instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.Log().WithField("request", "unbind").Warning("No provider found for request")
|
||||||
|
metrics.RequestsRejected.With(prometheus.Labels{
|
||||||
|
"outpost_name": ls.ac.Outpost.Name,
|
||||||
|
"type": "unbind",
|
||||||
|
"reason": "no_provider",
|
||||||
|
"app": "",
|
||||||
|
}).Inc()
|
||||||
|
return ldap.LDAPResultOperationsError, nil
|
||||||
|
}
|
|
@ -81,7 +81,7 @@ func (ps *ProxyServer) Type() string {
|
||||||
return "proxy"
|
return "proxy"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *ProxyServer) TimerFlowCacheExpiry() {}
|
func (ps *ProxyServer) TimerFlowCacheExpiry(context.Context) {}
|
||||||
|
|
||||||
func (ps *ProxyServer) GetCertificate(serverName string) *tls.Certificate {
|
func (ps *ProxyServer) GetCertificate(serverName string) *tls.Certificate {
|
||||||
app, ok := ps.apps[serverName]
|
app, ok := ps.apps[serverName]
|
||||||
|
@ -163,6 +163,10 @@ func (ps *ProxyServer) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ps *ProxyServer) Stop() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *ProxyServer) serve(listener net.Listener) {
|
func (ps *ProxyServer) serve(listener net.Listener) {
|
||||||
srv := &http.Server{Handler: ps.mux}
|
srv := &http.Server{Handler: ps.mux}
|
||||||
|
|
||||||
|
|
Reference in New Issue