diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3b76df232..dbb0e4d62 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,6 +62,8 @@ jobs: python-version: '3.8' - name: Install pyright run: npm install -g pyright + - name: Show pyright version + run: pyright --version - name: Install dependencies run: sudo pip install -U wheel pipenv && pipenv install --dev - name: Lint with pyright diff --git a/passbook/admin/templates/administration/overview.html b/passbook/admin/templates/administration/overview.html index d053f8c3d..21dcb8906 100644 --- a/passbook/admin/templates/administration/overview.html +++ b/passbook/admin/templates/administration/overview.html @@ -55,15 +55,26 @@
- {% if factor_count < 1 %} - {{ factor_count }} + {% if stage_count < 1 %} + {{ stage_count }}

{% trans 'No Stages configured. No Users will be able to login.' %}">

{% else %} - {{ factor_count }} + {{ stage_count }} {% endif %}
+ +
+
+ {% trans 'Flows' %} +
+
+
+ {{ flow_count }} +
+
+
diff --git a/passbook/core/migrations/0002_default_user.py b/passbook/core/migrations/0002_default_user.py new file mode 100644 index 000000000..66e6a2d3e --- /dev/null +++ b/passbook/core/migrations/0002_default_user.py @@ -0,0 +1,28 @@ +# Generated by Django 3.0.6 on 2020-05-23 16:40 + +from django.apps.registry import Apps +from django.db import migrations +from django.db.backends.base.schema import BaseDatabaseSchemaEditor + + +def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): + # User = apps.get_model("passbook_core", "User") + from passbook.core.models import User + + pbadmin = User.objects.create( + username="pbadmin", email="root@localhost", # password="pbadmin" + ) + pbadmin.set_password("pbadmin") # nosec + pbadmin.is_superuser = True + pbadmin.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_core", "0001_initial"), + ] + + operations = [ + migrations.RunPython(create_default_user), + ] diff --git a/passbook/core/signals.py b/passbook/core/signals.py index 01299f90e..74b6b49f1 100644 --- a/passbook/core/signals.py +++ b/passbook/core/signals.py @@ -1,31 +1,7 @@ """passbook core signals""" -from django.core.cache import cache from django.core.signals import Signal -from django.db.models.signals import post_save -from django.dispatch import receiver -from structlog import get_logger - -LOGGER = get_logger() user_signed_up = Signal(providing_args=["request", "user"]) invitation_created = Signal(providing_args=["request", "invitation"]) invitation_used = Signal(providing_args=["request", "invitation", "user"]) password_changed = Signal(providing_args=["user", "password"]) - - -@receiver(post_save) -# pylint: disable=unused-argument -def invalidate_policy_cache(sender, instance, **_): - """Invalidate Policy cache when policy is updated""" - from passbook.policies.models import Policy, PolicyBinding - from passbook.policies.process import cache_key - - if isinstance(instance, Policy): - LOGGER.debug("Invalidating policy cache", policy=instance) - total = 0 - for binding in PolicyBinding.objects.filter(policy=instance): - prefix = cache_key(binding) + "*" - keys = cache.keys(prefix) - total += len(keys) - cache.delete_many(keys) - LOGGER.debug("Deleted keys", len=total) diff --git a/passbook/core/templates/user/settings.html b/passbook/core/templates/user/settings.html index edbcc4928..5e752f935 100644 --- a/passbook/core/templates/user/settings.html +++ b/passbook/core/templates/user/settings.html @@ -17,7 +17,9 @@
diff --git a/passbook/core/views/user.py b/passbook/core/views/user.py index bb575ee8e..4bc02b27f 100644 --- a/passbook/core/views/user.py +++ b/passbook/core/views/user.py @@ -1,4 +1,6 @@ """passbook core user views""" +from typing import Any, Dict + from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.messages.views import SuccessMessageMixin from django.urls import reverse_lazy @@ -6,6 +8,7 @@ from django.utils.translation import gettext as _ from django.views.generic import UpdateView from passbook.core.forms.users import UserDetailForm +from passbook.flows.models import Flow, FlowDesignation class UserSettingsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView): @@ -19,3 +22,11 @@ class UserSettingsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView): def get_object(self): return self.request.user + + def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]: + kwargs = super().get_context_data(**kwargs) + unenrollment_flow = Flow.with_policy( + self.request, designation=FlowDesignation.UNRENOLLMENT + ) + kwargs["unenrollment_enabled"] = bool(unenrollment_flow) + return kwargs diff --git a/passbook/crypto/forms.py b/passbook/crypto/forms.py index babf25919..79d5f7100 100644 --- a/passbook/crypto/forms.py +++ b/passbook/crypto/forms.py @@ -34,7 +34,6 @@ class CertificateKeyPairForm(forms.ModelForm): password=None, backend=default_backend(), ) - load_pem_x509_certificate(key_data.encode("utf-8"), default_backend()) except ValueError: raise forms.ValidationError("Unable to load private key.") return key_data diff --git a/passbook/crypto/migrations/0002_create_self_signed_kp.py b/passbook/crypto/migrations/0002_create_self_signed_kp.py new file mode 100644 index 000000000..66239b816 --- /dev/null +++ b/passbook/crypto/migrations/0002_create_self_signed_kp.py @@ -0,0 +1,26 @@ +# Generated by Django 3.0.6 on 2020-05-23 23:07 + +from django.db import migrations + + +def create_self_signed(apps, schema_editor): + CertificateKeyPair = apps.get_model("passbook_crypto", "CertificateKeyPair") + db_alias = schema_editor.connection.alias + from passbook.crypto.builder import CertificateBuilder + + builder = CertificateBuilder() + builder.build() + CertificateKeyPair.objects.using(db_alias).create( + name="passbook Self-signed Certificate", + certificate_data=builder.certificate, + key_data=builder.private_key, + ) + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_crypto", "0001_initial"), + ] + + operations = [migrations.RunPython(create_self_signed)] diff --git a/passbook/flows/models.py b/passbook/flows/models.py index de0147c99..ffb194386 100644 --- a/passbook/flows/models.py +++ b/passbook/flows/models.py @@ -3,12 +3,16 @@ from typing import Optional from uuid import uuid4 from django.db import models +from django.http import HttpRequest from django.utils.translation import gettext_lazy as _ from model_utils.managers import InheritanceManager +from structlog import get_logger from passbook.core.types import UIUserSettings from passbook.policies.models import PolicyBindingModel +LOGGER = get_logger() + class FlowDesignation(models.TextChoices): """Designation of what a Flow should be used for. At a later point, this @@ -62,10 +66,29 @@ class Flow(PolicyBindingModel): PolicyBindingModel, parent_link=True, on_delete=models.CASCADE, related_name="+" ) - def related_flow(self, designation: str) -> Optional["Flow"]: + @staticmethod + def with_policy(request: HttpRequest, **flow_filter) -> Optional["Flow"]: + """Get a Flow by `**flow_filter` and check if the request from `request` can access it.""" + from passbook.policies.engine import PolicyEngine + + flows = Flow.objects.filter(**flow_filter) + for flow in flows: + engine = PolicyEngine(flow, request.user, request) + engine.build() + result = engine.result + if result.passing: + LOGGER.debug("with_policy: flow passing", flow=flow) + return flow + LOGGER.warning( + "with_policy: flow not passing", flow=flow, messages=result.messages + ) + LOGGER.debug("with_policy: no flow found", filters=flow_filter) + return None + + def related_flow(self, designation: str, request: HttpRequest) -> Optional["Flow"]: """Get a related flow with `designation`. Currently this only queries Flows by `designation`, but will eventually use `self` for related lookups.""" - return Flow.objects.filter(designation=designation).first() + return Flow.with_policy(request, designation=designation) def __str__(self) -> str: return f"Flow {self.name} ({self.slug})" diff --git a/passbook/flows/planner.py b/passbook/flows/planner.py index e44e378ee..262279434 100644 --- a/passbook/flows/planner.py +++ b/passbook/flows/planner.py @@ -11,7 +11,6 @@ from passbook.core.models import User from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.models import Flow, Stage from passbook.policies.engine import PolicyEngine -from passbook.policies.types import PolicyResult LOGGER = get_logger() @@ -52,22 +51,12 @@ class FlowPlanner: self.use_cache = True self.flow = flow - def _check_flow_root_policies(self, request: HttpRequest) -> PolicyResult: - engine = PolicyEngine(self.flow, request.user, request) - engine.build() - return engine.result - def plan( self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None ) -> FlowPlan: """Check each of the flows' policies, check policies for each stage with PolicyBinding and return ordered list""" LOGGER.debug("f(plan): Starting planning process", flow=self.flow) - # First off, check the flow's direct policy bindings - # to make sure the user even has access to the flow - root_result = self._check_flow_root_policies(request) - if not root_result.passing: - raise FlowNonApplicableException(*root_result.messages) # Bit of a workaround here, if there is a pending user set in the default context # we use that user for our cache key # to make sure they don't get the generic response @@ -75,6 +64,16 @@ class FlowPlanner: user = default_context[PLAN_CONTEXT_PENDING_USER] else: user = request.user + # First off, check the flow's direct policy bindings + # to make sure the user even has access to the flow + engine = PolicyEngine(self.flow, user, request) + if default_context: + engine.request.context = default_context + engine.build() + result = engine.result + if not result.passing: + raise FlowNonApplicableException(result.messages) + # User is passing so far, check if we have a cached plan cached_plan_key = cache_key(self.flow, user) cached_plan = cache.get(cached_plan_key, None) if cached_plan and self.use_cache: @@ -82,6 +81,7 @@ class FlowPlanner: "f(plan): Taking plan from cache", flow=self.flow, key=cached_plan_key ) return cached_plan + LOGGER.debug("f(plan): building plan", flow=self.flow) plan = self._build_plan(user, request, default_context) cache.set(cache_key(self.flow, user), plan) if not plan.stages: diff --git a/passbook/flows/tests/test_planner.py b/passbook/flows/tests/test_planner.py index df0dc5a7e..af6ae98fa 100644 --- a/passbook/flows/tests/test_planner.py +++ b/passbook/flows/tests/test_planner.py @@ -1,5 +1,5 @@ """flow planner tests""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from django.core.cache import cache from django.shortcuts import reverse @@ -13,7 +13,7 @@ from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache from passbook.policies.types import PolicyResult from passbook.stages.dummy.models import DummyStage -POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False)) +POLICY_RESULT_MOCK = PropertyMock(return_value=PolicyResult(False)) TIME_NOW_MOCK = MagicMock(return_value=3) @@ -40,8 +40,7 @@ class TestFlowPlanner(TestCase): planner.plan(request) @patch( - "passbook.flows.planner.FlowPlanner._check_flow_root_policies", - POLICY_RESULT_MOCK, + "passbook.policies.engine.PolicyEngine.result", POLICY_RESULT_MOCK, ) def test_non_applicable_plan(self): """Test that empty plan raises exception""" diff --git a/passbook/flows/tests/test_views.py b/passbook/flows/tests/test_views.py index e6a2ad20c..cacbe2004 100644 --- a/passbook/flows/tests/test_views.py +++ b/passbook/flows/tests/test_views.py @@ -1,5 +1,5 @@ """flow views tests""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from django.shortcuts import reverse from django.test import Client, TestCase @@ -12,7 +12,7 @@ from passbook.lib.config import CONFIG from passbook.policies.types import PolicyResult from passbook.stages.dummy.models import DummyStage -POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False)) +POLICY_RESULT_MOCK = PropertyMock(return_value=PolicyResult(False)) class TestFlowExecutor(TestCase): @@ -45,8 +45,7 @@ class TestFlowExecutor(TestCase): self.assertEqual(cancel_mock.call_count, 1) @patch( - "passbook.flows.planner.FlowPlanner._check_flow_root_policies", - POLICY_RESULT_MOCK, + "passbook.policies.engine.PolicyEngine.result", POLICY_RESULT_MOCK, ) def test_invalid_non_applicable_flow(self): """Tests that a non-applicable flow returns the correct error message""" diff --git a/passbook/flows/views.py b/passbook/flows/views.py index f955106db..5c8c71f11 100644 --- a/passbook/flows/views.py +++ b/passbook/flows/views.py @@ -1,7 +1,7 @@ """passbook multi-stage authentication engine""" from typing import Any, Dict, Optional -from django.http import HttpRequest, HttpResponse +from django.http import Http404, HttpRequest, HttpResponse from django.shortcuts import get_object_or_404, redirect, reverse from django.utils.decorators import method_decorator from django.views.decorators.clickjacking import xframe_options_sameorigin @@ -164,7 +164,9 @@ class ToDefaultFlow(View): designation: Optional[FlowDesignation] = None def dispatch(self, request: HttpRequest) -> HttpResponse: - flow = get_object_or_404(Flow, designation=self.designation) + flow = Flow.with_policy(request, designation=self.designation) + if not flow: + raise Http404 # If user already has a pending plan, clear it so we don't have to later. if SESSION_KEY_PLAN in self.request.session: plan: FlowPlan = self.request.session[SESSION_KEY_PLAN] diff --git a/passbook/policies/apps.py b/passbook/policies/apps.py index 5795355b6..946f84609 100644 --- a/passbook/policies/apps.py +++ b/passbook/policies/apps.py @@ -1,4 +1,6 @@ """passbook policies app config""" +from importlib import import_module + from django.apps import AppConfig @@ -8,3 +10,7 @@ class PassbookPoliciesConfig(AppConfig): name = "passbook.policies" label = "passbook_policies" verbose_name = "passbook Policies" + + def ready(self): + """Load source_types from config file""" + import_module("passbook.policies.signals") diff --git a/passbook/policies/engine.py b/passbook/policies/engine.py index 143ad6473..5db4f8cfc 100644 --- a/passbook/policies/engine.py +++ b/passbook/policies/engine.py @@ -73,16 +73,20 @@ class PolicyEngine: """Build task group""" for binding in self._iter_bindings(): self._check_policy_type(binding.policy) - policy = binding.policy - cached_policy = cache.get(cache_key(binding, self.request.user), None) + key = cache_key(binding, self.request) + cached_policy = cache.get(key, None) if cached_policy and self.use_cache: - LOGGER.debug("P_ENG: Taking result from cache", policy=policy) + LOGGER.debug( + "P_ENG: Taking result from cache", + policy=binding.policy, + cache_key=key, + ) self.__cached_policies.append(cached_policy) continue - LOGGER.debug("P_ENG: Evaluating policy", policy=policy) + LOGGER.debug("P_ENG: Evaluating policy", policy=binding.policy) our_end, task_end = Pipe(False) task = PolicyProcess(binding, self.request, task_end) - LOGGER.debug("P_ENG: Starting Process", policy=policy) + LOGGER.debug("P_ENG: Starting Process", policy=binding.policy) task.start() self.__processes.append( PolicyProcessInfo(process=task, connection=our_end, binding=binding) @@ -103,7 +107,9 @@ class PolicyEngine: x.result for x in self.__processes if x.result ] for result in process_results + self.__cached_policies: - LOGGER.debug("P_ENG: result", passing=result.passing) + LOGGER.debug( + "P_ENG: result", passing=result.passing, messages=result.messages + ) if result.messages: messages += result.messages if not result.passing: diff --git a/passbook/policies/process.py b/passbook/policies/process.py index 1fb906c9f..a187627a6 100644 --- a/passbook/policies/process.py +++ b/passbook/policies/process.py @@ -6,7 +6,6 @@ from typing import Optional from django.core.cache import cache from structlog import get_logger -from passbook.core.models import User from passbook.policies.exceptions import PolicyException from passbook.policies.models import PolicyBinding from passbook.policies.types import PolicyRequest, PolicyResult @@ -14,11 +13,13 @@ from passbook.policies.types import PolicyRequest, PolicyResult LOGGER = get_logger() -def cache_key(binding: PolicyBinding, user: Optional[User] = None) -> str: +def cache_key(binding: PolicyBinding, request: PolicyRequest) -> str: """Generate Cache key for policy""" prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}" - if user: - prefix += f"#{user.pk}" + if request.http_request: + prefix += f"_{request.http_request.session.session_key}" + if request.user: + prefix += f"#{request.user.pk}" return prefix @@ -65,7 +66,7 @@ class PolicyProcess(Process): passing=policy_result.passing, user=self.request.user, ) - key = cache_key(self.binding, self.request.user) + key = cache_key(self.binding, self.request) cache.set(key, policy_result) LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key) return policy_result diff --git a/passbook/policies/signals.py b/passbook/policies/signals.py new file mode 100644 index 000000000..82e0b3d94 --- /dev/null +++ b/passbook/policies/signals.py @@ -0,0 +1,26 @@ +"""passbook policy signals""" +from django.core.cache import cache +from django.db.models.signals import post_save +from django.dispatch import receiver +from structlog import get_logger + +LOGGER = get_logger() + + +@receiver(post_save) +# pylint: disable=unused-argument +def invalidate_policy_cache(sender, instance, **_): + """Invalidate Policy cache when policy is updated""" + from passbook.policies.models import Policy, PolicyBinding + + if isinstance(instance, Policy): + LOGGER.debug("Invalidating policy cache", policy=instance) + total = 0 + for binding in PolicyBinding.objects.filter(policy=instance): + prefix = ( + f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}*" + ) + keys = cache.keys(prefix) + total += len(keys) + cache.delete_many(keys) + LOGGER.debug("Deleted keys", len=total) diff --git a/passbook/providers/saml/migrations/0002_default_saml_property_mappings.py b/passbook/providers/saml/migrations/0002_default_saml_property_mappings.py new file mode 100644 index 000000000..72575b6d6 --- /dev/null +++ b/passbook/providers/saml/migrations/0002_default_saml_property_mappings.py @@ -0,0 +1,63 @@ +# Generated by Django 3.0.6 on 2020-05-23 19:32 + +from django.db import migrations + + +def create_default_property_mappings(apps, schema_editor): + """Create default SAML Property Mappings""" + SAMLPropertyMapping = apps.get_model( + "passbook_providers_saml", "SAMLPropertyMapping" + ) + db_alias = schema_editor.connection.alias + defaults = [ + { + "FriendlyName": "eduPersonPrincipalName", + "Name": "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", + "Expression": "{{ user.email }}", + }, + { + "FriendlyName": "cn", + "Name": "urn:oid:2.5.4.3", + "Expression": "{{ user.name }}", + }, + { + "FriendlyName": "mail", + "Name": "urn:oid:0.9.2342.19200300.100.1.3", + "Expression": "{{ user.email }}", + }, + { + "FriendlyName": "displayName", + "Name": "urn:oid:2.16.840.1.113730.3.1.241", + "Expression": "{{ user.username }}", + }, + { + "FriendlyName": "uid", + "Name": "urn:oid:0.9.2342.19200300.100.1.1", + "Expression": "{{ user.pk }}", + }, + { + "FriendlyName": "member-of", + "Name": "member-of", + "Expression": "[{% for group in user.groups.all() %}'{{ group.name }}',{% endfor %}]", + }, + ] + for default in defaults: + SAMLPropertyMapping.objects.using(db_alias).get_or_create( + saml_name=default["Name"], + friendly_name=default["FriendlyName"], + expression=default["Expression"], + defaults={ + "name": f"Autogenerated SAML Mapping: {default['FriendlyName']} -> {default['Expression']}" + }, + ) + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_providers_saml", "0001_initial"), + ] + + operations = [ + migrations.RunPython(create_default_property_mappings), + ] diff --git a/passbook/sources/ldap/api.py b/passbook/sources/ldap/api.py index a51a5ce12..e5ad2677c 100644 --- a/passbook/sources/ldap/api.py +++ b/passbook/sources/ldap/api.py @@ -23,6 +23,7 @@ class LDAPSourceSerializer(ModelSerializer): "group_object_filter", "user_group_membership_field", "object_uniqueness_field", + "sync_users", "sync_groups", "sync_parent_group", "property_mappings", diff --git a/passbook/sources/ldap/connector.py b/passbook/sources/ldap/connector.py index 064a6a628..748c25e9e 100644 --- a/passbook/sources/ldap/connector.py +++ b/passbook/sources/ldap/connector.py @@ -16,26 +16,10 @@ LOGGER = get_logger() class Connector: """Wrapper for ldap3 to easily manage user authentication and creation""" - _server: ldap3.Server - _connection = ldap3.Connection _source: LDAPSource def __init__(self, source: LDAPSource): self._source = source - self._server = ldap3.Server(source.server_uri) # Implement URI parsing - - def bind(self): - """Bind using Source's Credentials""" - self._connection = ldap3.Connection( - self._server, - raise_exceptions=True, - user=self._source.bind_cn, - password=self._source.bind_password, - ) - - self._connection.bind() - if self._source.start_tls: - self._connection.start_tls() @staticmethod def encode_pass(password: str) -> bytes: @@ -45,19 +29,23 @@ class Connector: @property def base_dn_users(self) -> str: """Shortcut to get full base_dn for user lookups""" - return ",".join([self._source.additional_user_dn, self._source.base_dn]) + if self._source.additional_user_dn: + return f"{self._source.additional_user_dn},{self._source.base_dn}" + return self._source.base_dn @property def base_dn_groups(self) -> str: """Shortcut to get full base_dn for group lookups""" - return ",".join([self._source.additional_group_dn, self._source.base_dn]) + if self._source.additional_group_dn: + return f"{self._source.additional_group_dn},{self._source.base_dn}" + return self._source.base_dn def sync_groups(self): """Iterate over all LDAP Groups and create passbook_core.Group instances""" if not self._source.sync_groups: - LOGGER.debug("Group syncing is disabled for this Source") + LOGGER.warning("Group syncing is disabled for this Source") return - groups = self._connection.extend.standard.paged_search( + groups = self._source.connection.extend.standard.paged_search( search_base=self.base_dn_groups, search_filter=self._source.group_object_filter, search_scope=ldap3.SUBTREE, @@ -87,7 +75,10 @@ class Connector: def sync_users(self): """Iterate over all LDAP Users and create passbook_core.User instances""" - users = self._connection.extend.standard.paged_search( + if not self._source.sync_users: + LOGGER.warning("User syncing is disabled for this Source") + return + users = self._source.connection.extend.standard.paged_search( search_base=self.base_dn_users, search_filter=self._source.user_object_filter, search_scope=ldap3.SUBTREE, @@ -101,9 +92,9 @@ class Connector: LOGGER.warning("Cannot find uniqueness Field in attributes") continue try: + defaults = self._build_object_properties(attributes) user, created = User.objects.update_or_create( - attributes__ldap_uniq=uniq, - defaults=self._build_object_properties(attributes), + attributes__ldap_uniq=uniq, defaults=defaults, ) except IntegrityError as exc: LOGGER.warning("Failed to create user", exc=exc) @@ -123,7 +114,7 @@ class Connector: def sync_membership(self): """Iterate over all Users and assign Groups using memberOf Field""" - users = self._connection.extend.standard.paged_search( + users = self._source.connection.extend.standard.paged_search( search_base=self.base_dn_users, search_filter=self._source.user_object_filter, search_scope=ldap3.SUBTREE, @@ -220,7 +211,7 @@ class Connector: LOGGER.debug("Attempting Binding as user", user=user) try: temp_connection = ldap3.Connection( - self._server, + self._source.connection.server, user=user.attributes.get("distinguishedName"), password=password, raise_exceptions=True, diff --git a/passbook/sources/ldap/forms.py b/passbook/sources/ldap/forms.py index 249ebd5af..48d71d48a 100644 --- a/passbook/sources/ldap/forms.py +++ b/passbook/sources/ldap/forms.py @@ -26,6 +26,7 @@ class LDAPSourceForm(forms.ModelForm): "group_object_filter", "user_group_membership_field", "object_uniqueness_field", + "sync_users", "sync_groups", "sync_parent_group", "property_mappings", diff --git a/passbook/sources/ldap/migrations/0002_ldapsource_sync_users.py b/passbook/sources/ldap/migrations/0002_ldapsource_sync_users.py new file mode 100644 index 000000000..27a0da2b3 --- /dev/null +++ b/passbook/sources/ldap/migrations/0002_ldapsource_sync_users.py @@ -0,0 +1,18 @@ +# Generated by Django 3.0.6 on 2020-05-23 19:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_sources_ldap", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="ldapsource", + name="sync_users", + field=models.BooleanField(default=True), + ), + ] diff --git a/passbook/sources/ldap/migrations/0003_default_ldap_property_mappings.py b/passbook/sources/ldap/migrations/0003_default_ldap_property_mappings.py new file mode 100644 index 000000000..318952211 --- /dev/null +++ b/passbook/sources/ldap/migrations/0003_default_ldap_property_mappings.py @@ -0,0 +1,35 @@ +# Generated by Django 3.0.6 on 2020-05-23 19:30 + +from django.apps.registry import Apps +from django.db import migrations + + +def create_default_ad_property_mappings(apps: Apps, schema_editor): + LDAPPropertyMapping = apps.get_model("passbook_sources_ldap", "LDAPPropertyMapping") + mapping = { + "name": "{{ ldap.name }}", + "first_name": "{{ ldap.givenName }}", + "last_name": "{{ ldap.sn }}", + "username": "{{ ldap.sAMAccountName }}", + "email": "{{ ldap.mail }}", + } + db_alias = schema_editor.connection.alias + for object_field, expression in mapping.items(): + LDAPPropertyMapping.objects.using(db_alias).get_or_create( + expression=expression, + object_field=object_field, + defaults={ + "name": f"Autogenerated LDAP Mapping: {expression} -> {object_field}" + }, + ) + + +class Migration(migrations.Migration): + + dependencies = [ + ("passbook_sources_ldap", "0002_ldapsource_sync_users"), + ] + + operations = [ + migrations.RunPython(create_default_ad_property_mappings), + ] diff --git a/passbook/sources/ldap/models.py b/passbook/sources/ldap/models.py index 393cccfa2..34fa96e56 100644 --- a/passbook/sources/ldap/models.py +++ b/passbook/sources/ldap/models.py @@ -1,8 +1,10 @@ """passbook LDAP Models""" +from typing import Optional from django.core.validators import URLValidator from django.db import models from django.utils.translation import gettext_lazy as _ +from ldap3 import Connection, Server from passbook.core.models import Group, PropertyMapping, Source @@ -22,10 +24,12 @@ class LDAPSource(Source): additional_user_dn = models.TextField( help_text=_("Prepended to Base DN for User-queries."), verbose_name=_("Addition User DN"), + blank=True, ) additional_group_dn = models.TextField( help_text=_("Prepended to Base DN for Group-queries."), verbose_name=_("Addition Group DN"), + blank=True, ) user_object_filter = models.TextField( @@ -43,6 +47,7 @@ class LDAPSource(Source): default="objectSid", help_text=_("Field which contains a unique Identifier.") ) + sync_users = models.BooleanField(default=True) sync_groups = models.BooleanField(default=True) sync_parent_group = models.ForeignKey( Group, blank=True, null=True, default=None, on_delete=models.SET_DEFAULT @@ -50,6 +55,25 @@ class LDAPSource(Source): form = "passbook.sources.ldap.forms.LDAPSourceForm" + _connection: Optional[Connection] + + @property + def connection(self) -> Connection: + """Get a fully connected and bound LDAP Connection""" + if not self._connection: + server = Server(self.server_uri) + self._connection = Connection( + server, + raise_exceptions=True, + user=self.bind_cn, + password=self.bind_password, + ) + + self._connection.bind() + if self.start_tls: + self._connection.start_tls() + return self._connection + class Meta: verbose_name = _("LDAP Source") diff --git a/passbook/sources/ldap/tasks.py b/passbook/sources/ldap/tasks.py index 581d27c7a..eeb1cb282 100644 --- a/passbook/sources/ldap/tasks.py +++ b/passbook/sources/ldap/tasks.py @@ -9,7 +9,6 @@ def sync_groups(source_pk: int): """Sync LDAP Groups on background worker""" source = LDAPSource.objects.get(pk=source_pk) connector = Connector(source) - connector.bind() connector.sync_groups() @@ -18,7 +17,6 @@ def sync_users(source_pk: int): """Sync LDAP Users on background worker""" source = LDAPSource.objects.get(pk=source_pk) connector = Connector(source) - connector.bind() connector.sync_users() @@ -27,7 +25,6 @@ def sync(): """Sync all sources""" for source in LDAPSource.objects.filter(enabled=True): connector = Connector(source) - connector.bind() connector.sync_users() connector.sync_groups() connector.sync_membership() diff --git a/passbook/sources/ldap/tests.py b/passbook/sources/ldap/tests.py new file mode 100644 index 000000000..faa3f4177 --- /dev/null +++ b/passbook/sources/ldap/tests.py @@ -0,0 +1,75 @@ +"""LDAP Source tests""" +from unittest.mock import PropertyMock, patch + +from django.test import TestCase +from ldap3 import MOCK_SYNC, OFFLINE_AD_2012_R2, Connection, Server + +from passbook.core.models import User +from passbook.sources.ldap.connector import Connector +from passbook.sources.ldap.models import LDAPPropertyMapping, LDAPSource + + +def _build_mock_connection() -> Connection: + """Create mock connection""" + server = Server("my_fake_server", get_info=OFFLINE_AD_2012_R2) + _pass = "foo" # noqa # nosec + connection = Connection( + server, + user="cn=my_user,ou=test,o=lab", + password=_pass, + client_strategy=MOCK_SYNC, + ) + connection.strategy.add_entry( + "cn=user0,ou=test,o=lab", + { + "userPassword": "test0000", + "sAMAccountName": "user0_sn", + "revision": 0, + "objectSid": "unique-test0000", + "objectCategory": "Person", + }, + ) + connection.strategy.add_entry( + "cn=user1,ou=test,o=lab", + { + "userPassword": "test1111", + "sAMAccountName": "user1_sn", + "revision": 0, + "objectSid": "unique-test1111", + "objectCategory": "Person", + }, + ) + connection.strategy.add_entry( + "cn=user2,ou=test,o=lab", + { + "userPassword": "test2222", + "sAMAccountName": "user2_sn", + "revision": 0, + "objectSid": "unique-test2222", + "objectCategory": "Person", + }, + ) + connection.bind() + return connection + + +LDAP_CONNECTION_PATCH = PropertyMock(return_value=_build_mock_connection()) + + +class LDAPSourceTests(TestCase): + """LDAP Source tests""" + + def setUp(self): + self.source = LDAPSource.objects.create( + name="ldap", slug="ldap", base_dn="o=lab" + ) + self.source.property_mappings.set(LDAPPropertyMapping.objects.all()) + self.source.save() + + @patch("passbook.sources.ldap.models.LDAPSource.connection", LDAP_CONNECTION_PATCH) + def test_sync_users(self): + """Test user sync""" + connector = Connector(self.source) + connector.sync_users() + user = User.objects.filter(username="user2_sn") + self.assertTrue(user.exists()) diff --git a/passbook/sources/oauth/clients.py b/passbook/sources/oauth/clients.py index 06a9310fc..35c58d7ba 100644 --- a/passbook/sources/oauth/clients.py +++ b/passbook/sources/oauth/clients.py @@ -1,6 +1,6 @@ """OAuth Clients""" import json -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from urllib.parse import parse_qs, urlencode from django.http import HttpRequest @@ -14,24 +14,29 @@ from structlog import get_logger from passbook import __version__ LOGGER = get_logger() +if TYPE_CHECKING: + from passbook.sources.oauth.models import OAuthSource class BaseOAuthClient: """Base OAuth Client""" session: Session + source: "OAuthSource" - def __init__(self, source, token=""): # nosec + def __init__(self, source: "OAuthSource", token=""): # nosec self.source = source self.token = token self.session = Session() self.session.headers.update({"User-Agent": "passbook %s" % __version__}) - def get_access_token(self, request, callback=None): + def get_access_token( + self, request: HttpRequest, callback=None + ) -> Optional[Dict[str, Any]]: "Fetch access token from callback request." raise NotImplementedError("Defined in a sub-class") # pragma: no cover - def get_profile_info(self, token: Dict[str, str]): + def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]: "Fetch user profile information." try: headers = { @@ -45,7 +50,7 @@ class BaseOAuthClient: LOGGER.warning("Unable to fetch user profile", exc=exc) return None else: - return response.json() or response.text + return response.json() def get_redirect_args(self, request, callback) -> Dict[str, str]: "Get request parameters for redirect url." diff --git a/passbook/sources/oauth/views/core.py b/passbook/sources/oauth/views/core.py index 7e3249bc7..9166ad3b6 100644 --- a/passbook/sources/oauth/views/core.py +++ b/passbook/sources/oauth/views/core.py @@ -21,7 +21,7 @@ from passbook.flows.planner import ( ) from passbook.flows.views import SESSION_KEY_PLAN from passbook.lib.utils.urls import redirect_with_qs -from passbook.sources.oauth.clients import get_client +from passbook.sources.oauth.clients import BaseOAuthClient, get_client from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND @@ -34,7 +34,7 @@ class OAuthClientMixin: client_class: Optional[Callable] = None - def get_client(self, source): + def get_client(self, source: OAuthSource) -> BaseOAuthClient: "Get instance of the OAuth client for this source." if self.client_class is not None: # pylint: disable=not-callable diff --git a/passbook/stages/identification/forms.py b/passbook/stages/identification/forms.py index 04217f7f8..882ce0f03 100644 --- a/passbook/stages/identification/forms.py +++ b/passbook/stages/identification/forms.py @@ -16,7 +16,7 @@ class IdentificationStageForm(forms.ModelForm): class Meta: model = IdentificationStage - fields = ["name", "user_fields", "template"] + fields = ["name", "user_fields", "template", "enrollment_flow", "recovery_flow"] widgets = { "name": forms.TextInput(), } diff --git a/passbook/stages/user_write/tests.py b/passbook/stages/user_write/tests.py index 5bad06809..d37012207 100644 --- a/passbook/stages/user_write/tests.py +++ b/passbook/stages/user_write/tests.py @@ -72,6 +72,7 @@ class TestUserWriteStage(TestCase): plan.context[PLAN_CONTEXT_PROMPT] = { "username": "test-user-new", "password": new_password, + "some-custom-attribute": "test", } session = self.client.session session[SESSION_KEY_PLAN] = plan @@ -88,6 +89,7 @@ class TestUserWriteStage(TestCase): ) self.assertTrue(user_qs.exists()) self.assertTrue(user_qs.first().check_password(new_password)) + self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test") def test_without_data(self): """Test without data results in error""" diff --git a/swagger.yaml b/swagger.yaml index ac38448f0..31c87f5be 100755 --- a/swagger.yaml +++ b/swagger.yaml @@ -5606,8 +5606,6 @@ definitions: - bind_cn - bind_password - base_dn - - additional_user_dn - - additional_group_dn type: object properties: pk: @@ -5654,12 +5652,10 @@ definitions: title: Addition User DN description: Prepended to Base DN for User-queries. type: string - minLength: 1 additional_group_dn: title: Addition Group DN description: Prepended to Base DN for Group-queries. type: string - minLength: 1 user_object_filter: title: User object filter description: Consider Objects matching this filter to be Users. @@ -5680,6 +5676,9 @@ definitions: description: Field which contains a unique Identifier. type: string minLength: 1 + sync_users: + title: Sync users + type: boolean sync_groups: title: Sync groups type: boolean