diff --git a/authentik/lib/config.py b/authentik/lib/config.py index 043f77460..63aa3493a 100644 --- a/authentik/lib/config.py +++ b/authentik/lib/config.py @@ -24,7 +24,7 @@ ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local") def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any: - """Recursively walk through `root`, checking each part of `path` split by `sep`. + """Recursively walk through `root`, checking each part of `path` separated by `sep`. If at any point a dict does not exist, return default""" for comp in path.split(sep): if root and comp in root: @@ -34,7 +34,19 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any: return root -@dataclass +def set_path_in_dict(root: dict, path: str, value: Any, sep="."): + """Recursively walk through `root`, checking each part of `path` separated by `sep` + and setting the last value to `value`""" + # Walk each component of the path + path_parts = path.split(sep) + for comp in path_parts[:-1]: + if comp not in root: + root[comp] = {} + root = root.get(comp, {}) + root[path_parts[-1]] = value + + +@dataclass(slots=True) class Attr: """Single configuration attribute""" @@ -55,6 +67,10 @@ class Attr: # to the config file containing this change or the file containing this value source: Optional[str] = field(default=None) + def __post_init__(self): + if isinstance(self.value, Attr): + raise RuntimeError(f"config Attr with nested Attr for source {self.source}") + class AttrEncoder(JSONEncoder): """JSON encoder that can deal with `Attr` classes""" @@ -227,15 +243,7 @@ class ConfigLoader: def set(self, path: str, value: Any, sep="."): """Set value using same syntax as get()""" - # Walk sub_dicts before parsing path - root = self.raw - # Walk each component of the path - path_parts = path.split(sep) - for comp in path_parts[:-1]: - if comp not in root: - root[comp] = {} - root = root.get(comp, {}) - root[path_parts[-1]] = Attr(value) + set_path_in_dict(self.raw, path, Attr(value), sep=sep) CONFIG = ConfigLoader() diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index a131d935c..7490449ec 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -9,7 +9,7 @@ from structlog.stdlib import BoundLogger, get_logger from authentik.core.exceptions import PropertyMappingExpressionException from authentik.events.models import Event, EventAction -from authentik.lib.config import CONFIG +from authentik.lib.config import CONFIG, set_path_in_dict from authentik.lib.merge import MERGE_LIST_UNIQUE from authentik.sources.ldap.auth import LDAP_DISTINGUISHED_NAME from authentik.sources.ldap.models import LDAPPropertyMapping, LDAPSource @@ -164,7 +164,7 @@ class BaseLDAPSynchronizer: if object_field.startswith("attributes."): # Because returning a list might desired, we can't # rely on self._flatten here. Instead, just save the result as-is - properties["attributes"][object_field.replace("attributes.", "")] = value + set_path_in_dict(properties, object_field, value) else: properties[object_field] = self._flatten(value) except PropertyMappingExpressionException as exc: diff --git a/authentik/stages/user_write/stage.py b/authentik/stages/user_write/stage.py index 98494fae1..5a4c80974 100644 --- a/authentik/stages/user_write/stage.py +++ b/authentik/stages/user_write/stage.py @@ -14,6 +14,7 @@ from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import StageView from authentik.flows.views.executor import FlowExecutorView +from authentik.lib.config import set_path_in_dict from authentik.stages.password import BACKEND_INBUILT from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT @@ -44,12 +45,7 @@ class UserWriteStageView(StageView): # this is just a sanity check to ensure that is removed if parts[0] == "attributes": parts = parts[1:] - attrs = user.attributes - for comp in parts[:-1]: - if comp not in attrs: - attrs[comp] = {} - attrs = attrs.get(comp) - attrs[parts[-1]] = value + set_path_in_dict(user.attributes, ".".join(parts), value) def ensure_user(self) -> tuple[Optional[User], bool]: """Ensure a user exists"""