This repository has been archived on 2024-05-31. You can view files and clone it, but cannot push or open issues or pull requests.
authentik/internal/outpost/ldap/search_route.go

56 lines
1.9 KiB
Go
Raw Permalink Normal View History

package ldap
import (
"strings"
"beryju.io/ldap"
goldap "github.com/go-ldap/ldap/v3"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/search"
)
func (ls *LDAPServer) providerForRequest(req *search.Request) *ProviderInstance {
parsedBaseDN, err := goldap.ParseDN(strings.ToLower(req.BaseDN))
if err != nil {
req.Log().WithError(err).Info("failed to parse base DN")
return nil
}
parsedBindDN, err := goldap.ParseDN(strings.ToLower(req.BindDN))
if err != nil {
req.Log().WithError(err).Info("failed to parse bind DN")
return nil
}
var selectedProvider *ProviderInstance
longestMatch := 0
for _, provider := range ls.providers {
providerBase, _ := goldap.ParseDN(strings.ToLower(provider.BaseDN))
// Try to match the provider primarily based on the search request's base DN
baseDNMatches := providerBase.AncestorOf(parsedBaseDN) || providerBase.Equal(parsedBaseDN)
// But also try to match the provider based on the bind DN
bindDNMatches := providerBase.AncestorOf(parsedBindDN) || providerBase.Equal(parsedBindDN)
if baseDNMatches || bindDNMatches {
// Only select the provider if it's a more precise match than previously
if len(provider.BaseDN) > longestMatch {
req.Log().WithField("provider", provider.BaseDN).Trace("selecting provider for search request")
selectedProvider = provider
longestMatch = len(provider.BaseDN)
}
}
}
return selectedProvider
}
func (ls *LDAPServer) searchRoute(req *search.Request, pi *ProviderInstance) (ldap.ServerSearchResult, error) {
// Route based on the base DN
if len(req.BaseDN) == 0 {
req.Log().Trace("routing to base")
return pi.searcher.SearchBase(req)
}
if strings.EqualFold(req.BaseDN, "cn=subschema") || req.FilterObjectClass == constants.OCSubSchema {
req.Log().Trace("routing to subschema")
return pi.searcher.SearchSubschema(req)
}
req.Log().Trace("routing to default")
return pi.searcher.Search(req)
}