package ldap import ( "context" "errors" "fmt" "strings" "sync" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" "goauthentik.io/internal/outpost/ldap/bind" directbind "goauthentik.io/internal/outpost/ldap/bind/direct" memorybind "goauthentik.io/internal/outpost/ldap/bind/memory" "goauthentik.io/internal/outpost/ldap/constants" "goauthentik.io/internal/outpost/ldap/flags" directsearch "goauthentik.io/internal/outpost/ldap/search/direct" memorysearch "goauthentik.io/internal/outpost/ldap/search/memory" ) func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance { for _, p := range ls.providers { if p.outpostPk == pk { return p } } return nil } func (ls *LDAPServer) getInvalidationFlow() string { req, _, err := ls.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute() if err != nil { ls.log.WithError(err).Warning("failed to fetch brand config") return "" } flow := req.GetFlowInvalidation() return flow } func (ls *LDAPServer) Refresh() error { outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute() if err != nil { return err } if len(outposts.Results) < 1 { return errors.New("no ldap provider defined") } providers := make([]*ProviderInstance, len(outposts.Results)) invalidationFlow := ls.getInvalidationFlow() for idx, provider := range outposts.Results { userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn)) groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn)) logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name) // Get existing instance so we can transfer boundUsers existing := ls.getCurrentProvider(provider.Pk) usersMutex := &sync.RWMutex{} users := make(map[string]*flags.UserFlags) if existing != nil { usersMutex = existing.boundUsersMutex // Shallow copy, no need to lock users = existing.boundUsers } providers[idx] = &ProviderInstance{ BaseDN: provider.GetBaseDn(), VirtualGroupDN: virtualGroupDN, GroupDN: groupDN, UserDN: userDN, appSlug: provider.ApplicationSlug, authenticationFlowSlug: provider.BindFlowSlug, invalidationFlowSlug: invalidationFlow, searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, boundUsersMutex: usersMutex, boundUsers: users, s: ls, log: logger, tlsServerName: provider.TlsServerName, uidStartNumber: provider.GetUidStartNumber(), gidStartNumber: provider.GetGidStartNumber(), mfaSupport: provider.GetMfaSupport(), outpostName: ls.ac.Outpost.Name, outpostPk: provider.Pk, } if kp := provider.Certificate.Get(); kp != nil { err := ls.cs.AddKeypair(*kp) if err != nil { ls.log.WithError(err).Warning("Failed to initially fetch certificate") } providers[idx].cert = ls.cs.Get(*kp) providers[idx].certUUID = *kp } if *provider.SearchMode.Ptr() == api.LDAPAPIACCESSMODE_CACHED { providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx]) } else if *provider.SearchMode.Ptr() == api.LDAPAPIACCESSMODE_DIRECT { providers[idx].searcher = directsearch.NewDirectSearcher(providers[idx]) } if *provider.BindMode.Ptr() == api.LDAPAPIACCESSMODE_CACHED { var oldBinder bind.Binder if existing != nil { oldBinder = existing.binder } providers[idx].binder = memorybind.NewSessionBinder(providers[idx], oldBinder) } else if *provider.BindMode.Ptr() == api.LDAPAPIACCESSMODE_DIRECT { providers[idx].binder = directbind.NewDirectBinder(providers[idx]) } } ls.providers = providers ls.log.Info("Update providers") return nil }