Merge branch 'master' into docs-flows

This commit is contained in:
Jens Langhammer 2020-06-02 20:25:43 +02:00
commit c4facd53b4
31 changed files with 430 additions and 102 deletions

View File

@ -62,6 +62,8 @@ jobs:
python-version: '3.8' python-version: '3.8'
- name: Install pyright - name: Install pyright
run: npm install -g pyright run: npm install -g pyright
- name: Show pyright version
run: pyright --version
- name: Install dependencies - name: Install dependencies
run: sudo pip install -U wheel pipenv && pipenv install --dev run: sudo pip install -U wheel pipenv && pipenv install --dev
- name: Lint with pyright - name: Lint with pyright

View File

@ -55,15 +55,26 @@
</div> </div>
</div> </div>
<div class="pf-c-card__body"> <div class="pf-c-card__body">
{% if factor_count < 1 %} {% if stage_count < 1 %}
<i class="pficon-error-circle-o"></i> {{ factor_count }} <i class="pficon-error-circle-o"></i> {{ stage_count }}
<p>{% trans 'No Stages configured. No Users will be able to login.' %}"></p> <p>{% trans 'No Stages configured. No Users will be able to login.' %}"></p>
{% else %} {% else %}
<i class="pf-icon pf-icon-ok"></i> {{ factor_count }} <i class="pf-icon pf-icon-ok"></i> {{ stage_count }}
{% endif %} {% endif %}
</div> </div>
</a> </a>
<a href="{% url 'passbook_admin:stages' %}" class="pf-c-card pf-m-hoverable pf-m-compact">
<div class="pf-c-card__head">
<div class="pf-c-card__head-main">
<i class="pf-icon pf-icon-topology"></i> {% trans 'Flows' %}
</div>
</div>
<div class="pf-c-card__body">
<i class="pf-icon pf-icon-ok"></i> {{ flow_count }}
</div>
</a>
<a href="{% url 'passbook_admin:policies' %}" class="pf-c-card pf-m-hoverable pf-m-compact"> <a href="{% url 'passbook_admin:policies' %}" class="pf-c-card pf-m-hoverable pf-m-compact">
<div class="pf-c-card__head"> <div class="pf-c-card__head">
<div class="pf-c-card__head-main"> <div class="pf-c-card__head-main">

View File

@ -0,0 +1,28 @@
# Generated by Django 3.0.6 on 2020-05-23 16:40
from django.apps.registry import Apps
from django.db import migrations
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
# User = apps.get_model("passbook_core", "User")
from passbook.core.models import User
pbadmin = User.objects.create(
username="pbadmin", email="root@localhost", # password="pbadmin"
)
pbadmin.set_password("pbadmin") # nosec
pbadmin.is_superuser = True
pbadmin.save()
class Migration(migrations.Migration):
dependencies = [
("passbook_core", "0001_initial"),
]
operations = [
migrations.RunPython(create_default_user),
]

View File

@ -1,31 +1,7 @@
"""passbook core signals""" """passbook core signals"""
from django.core.cache import cache
from django.core.signals import Signal from django.core.signals import Signal
from django.db.models.signals import post_save
from django.dispatch import receiver
from structlog import get_logger
LOGGER = get_logger()
user_signed_up = Signal(providing_args=["request", "user"]) user_signed_up = Signal(providing_args=["request", "user"])
invitation_created = Signal(providing_args=["request", "invitation"]) invitation_created = Signal(providing_args=["request", "invitation"])
invitation_used = Signal(providing_args=["request", "invitation", "user"]) invitation_used = Signal(providing_args=["request", "invitation", "user"])
password_changed = Signal(providing_args=["user", "password"]) password_changed = Signal(providing_args=["user", "password"])
@receiver(post_save)
# pylint: disable=unused-argument
def invalidate_policy_cache(sender, instance, **_):
"""Invalidate Policy cache when policy is updated"""
from passbook.policies.models import Policy, PolicyBinding
from passbook.policies.process import cache_key
if isinstance(instance, Policy):
LOGGER.debug("Invalidating policy cache", policy=instance)
total = 0
for binding in PolicyBinding.objects.filter(policy=instance):
prefix = cache_key(binding) + "*"
keys = cache.keys(prefix)
total += len(keys)
cache.delete_many(keys)
LOGGER.debug("Deleted keys", len=total)

View File

@ -17,7 +17,9 @@
<div class="pf-c-form__horizontal-group"> <div class="pf-c-form__horizontal-group">
<div class="pf-c-form__actions"> <div class="pf-c-form__actions">
<input class="pf-c-button pf-m-primary" type="submit" value="{% trans 'Update' %}" /> <input class="pf-c-button pf-m-primary" type="submit" value="{% trans 'Update' %}" />
{% if unenrollment_enabled %}
<a class="pf-c-button pf-m-danger" href="{% url 'passbook_flows:default-unenrollment' %}?back={{ request.get_full_path }}">{% trans "Delete account" %}</a> <a class="pf-c-button pf-m-danger" href="{% url 'passbook_flows:default-unenrollment' %}?back={{ request.get_full_path }}">{% trans "Delete account" %}</a>
{% endif %}
</div> </div>
</div> </div>
</div> </div>

View File

@ -1,4 +1,6 @@
"""passbook core user views""" """passbook core user views"""
from typing import Any, Dict
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.messages.views import SuccessMessageMixin from django.contrib.messages.views import SuccessMessageMixin
from django.urls import reverse_lazy from django.urls import reverse_lazy
@ -6,6 +8,7 @@ from django.utils.translation import gettext as _
from django.views.generic import UpdateView from django.views.generic import UpdateView
from passbook.core.forms.users import UserDetailForm from passbook.core.forms.users import UserDetailForm
from passbook.flows.models import Flow, FlowDesignation
class UserSettingsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView): class UserSettingsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView):
@ -19,3 +22,11 @@ class UserSettingsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView):
def get_object(self): def get_object(self):
return self.request.user return self.request.user
def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
kwargs = super().get_context_data(**kwargs)
unenrollment_flow = Flow.with_policy(
self.request, designation=FlowDesignation.UNRENOLLMENT
)
kwargs["unenrollment_enabled"] = bool(unenrollment_flow)
return kwargs

View File

@ -34,7 +34,6 @@ class CertificateKeyPairForm(forms.ModelForm):
password=None, password=None,
backend=default_backend(), backend=default_backend(),
) )
load_pem_x509_certificate(key_data.encode("utf-8"), default_backend())
except ValueError: except ValueError:
raise forms.ValidationError("Unable to load private key.") raise forms.ValidationError("Unable to load private key.")
return key_data return key_data

View File

@ -0,0 +1,26 @@
# Generated by Django 3.0.6 on 2020-05-23 23:07
from django.db import migrations
def create_self_signed(apps, schema_editor):
CertificateKeyPair = apps.get_model("passbook_crypto", "CertificateKeyPair")
db_alias = schema_editor.connection.alias
from passbook.crypto.builder import CertificateBuilder
builder = CertificateBuilder()
builder.build()
CertificateKeyPair.objects.using(db_alias).create(
name="passbook Self-signed Certificate",
certificate_data=builder.certificate,
key_data=builder.private_key,
)
class Migration(migrations.Migration):
dependencies = [
("passbook_crypto", "0001_initial"),
]
operations = [migrations.RunPython(create_self_signed)]

View File

@ -3,12 +3,16 @@ from typing import Optional
from uuid import uuid4 from uuid import uuid4
from django.db import models from django.db import models
from django.http import HttpRequest
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from model_utils.managers import InheritanceManager from model_utils.managers import InheritanceManager
from structlog import get_logger
from passbook.core.types import UIUserSettings from passbook.core.types import UIUserSettings
from passbook.policies.models import PolicyBindingModel from passbook.policies.models import PolicyBindingModel
LOGGER = get_logger()
class FlowDesignation(models.TextChoices): class FlowDesignation(models.TextChoices):
"""Designation of what a Flow should be used for. At a later point, this """Designation of what a Flow should be used for. At a later point, this
@ -62,10 +66,29 @@ class Flow(PolicyBindingModel):
PolicyBindingModel, parent_link=True, on_delete=models.CASCADE, related_name="+" PolicyBindingModel, parent_link=True, on_delete=models.CASCADE, related_name="+"
) )
def related_flow(self, designation: str) -> Optional["Flow"]: @staticmethod
def with_policy(request: HttpRequest, **flow_filter) -> Optional["Flow"]:
"""Get a Flow by `**flow_filter` and check if the request from `request` can access it."""
from passbook.policies.engine import PolicyEngine
flows = Flow.objects.filter(**flow_filter)
for flow in flows:
engine = PolicyEngine(flow, request.user, request)
engine.build()
result = engine.result
if result.passing:
LOGGER.debug("with_policy: flow passing", flow=flow)
return flow
LOGGER.warning(
"with_policy: flow not passing", flow=flow, messages=result.messages
)
LOGGER.debug("with_policy: no flow found", filters=flow_filter)
return None
def related_flow(self, designation: str, request: HttpRequest) -> Optional["Flow"]:
"""Get a related flow with `designation`. Currently this only queries """Get a related flow with `designation`. Currently this only queries
Flows by `designation`, but will eventually use `self` for related lookups.""" Flows by `designation`, but will eventually use `self` for related lookups."""
return Flow.objects.filter(designation=designation).first() return Flow.with_policy(request, designation=designation)
def __str__(self) -> str: def __str__(self) -> str:
return f"Flow {self.name} ({self.slug})" return f"Flow {self.name} ({self.slug})"

View File

@ -11,7 +11,6 @@ from passbook.core.models import User
from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from passbook.flows.models import Flow, Stage from passbook.flows.models import Flow, Stage
from passbook.policies.engine import PolicyEngine from passbook.policies.engine import PolicyEngine
from passbook.policies.types import PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
@ -52,22 +51,12 @@ class FlowPlanner:
self.use_cache = True self.use_cache = True
self.flow = flow self.flow = flow
def _check_flow_root_policies(self, request: HttpRequest) -> PolicyResult:
engine = PolicyEngine(self.flow, request.user, request)
engine.build()
return engine.result
def plan( def plan(
self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None
) -> FlowPlan: ) -> FlowPlan:
"""Check each of the flows' policies, check policies for each stage with PolicyBinding """Check each of the flows' policies, check policies for each stage with PolicyBinding
and return ordered list""" and return ordered list"""
LOGGER.debug("f(plan): Starting planning process", flow=self.flow) LOGGER.debug("f(plan): Starting planning process", flow=self.flow)
# First off, check the flow's direct policy bindings
# to make sure the user even has access to the flow
root_result = self._check_flow_root_policies(request)
if not root_result.passing:
raise FlowNonApplicableException(*root_result.messages)
# Bit of a workaround here, if there is a pending user set in the default context # Bit of a workaround here, if there is a pending user set in the default context
# we use that user for our cache key # we use that user for our cache key
# to make sure they don't get the generic response # to make sure they don't get the generic response
@ -75,6 +64,16 @@ class FlowPlanner:
user = default_context[PLAN_CONTEXT_PENDING_USER] user = default_context[PLAN_CONTEXT_PENDING_USER]
else: else:
user = request.user user = request.user
# First off, check the flow's direct policy bindings
# to make sure the user even has access to the flow
engine = PolicyEngine(self.flow, user, request)
if default_context:
engine.request.context = default_context
engine.build()
result = engine.result
if not result.passing:
raise FlowNonApplicableException(result.messages)
# User is passing so far, check if we have a cached plan
cached_plan_key = cache_key(self.flow, user) cached_plan_key = cache_key(self.flow, user)
cached_plan = cache.get(cached_plan_key, None) cached_plan = cache.get(cached_plan_key, None)
if cached_plan and self.use_cache: if cached_plan and self.use_cache:
@ -82,6 +81,7 @@ class FlowPlanner:
"f(plan): Taking plan from cache", flow=self.flow, key=cached_plan_key "f(plan): Taking plan from cache", flow=self.flow, key=cached_plan_key
) )
return cached_plan return cached_plan
LOGGER.debug("f(plan): building plan", flow=self.flow)
plan = self._build_plan(user, request, default_context) plan = self._build_plan(user, request, default_context)
cache.set(cache_key(self.flow, user), plan) cache.set(cache_key(self.flow, user), plan)
if not plan.stages: if not plan.stages:

View File

@ -1,5 +1,5 @@
"""flow planner tests""" """flow planner tests"""
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, PropertyMock, patch
from django.core.cache import cache from django.core.cache import cache
from django.shortcuts import reverse from django.shortcuts import reverse
@ -13,7 +13,7 @@ from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache
from passbook.policies.types import PolicyResult from passbook.policies.types import PolicyResult
from passbook.stages.dummy.models import DummyStage from passbook.stages.dummy.models import DummyStage
POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False)) POLICY_RESULT_MOCK = PropertyMock(return_value=PolicyResult(False))
TIME_NOW_MOCK = MagicMock(return_value=3) TIME_NOW_MOCK = MagicMock(return_value=3)
@ -40,8 +40,7 @@ class TestFlowPlanner(TestCase):
planner.plan(request) planner.plan(request)
@patch( @patch(
"passbook.flows.planner.FlowPlanner._check_flow_root_policies", "passbook.policies.engine.PolicyEngine.result", POLICY_RESULT_MOCK,
POLICY_RESULT_MOCK,
) )
def test_non_applicable_plan(self): def test_non_applicable_plan(self):
"""Test that empty plan raises exception""" """Test that empty plan raises exception"""

View File

@ -1,5 +1,5 @@
"""flow views tests""" """flow views tests"""
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, PropertyMock, patch
from django.shortcuts import reverse from django.shortcuts import reverse
from django.test import Client, TestCase from django.test import Client, TestCase
@ -12,7 +12,7 @@ from passbook.lib.config import CONFIG
from passbook.policies.types import PolicyResult from passbook.policies.types import PolicyResult
from passbook.stages.dummy.models import DummyStage from passbook.stages.dummy.models import DummyStage
POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False)) POLICY_RESULT_MOCK = PropertyMock(return_value=PolicyResult(False))
class TestFlowExecutor(TestCase): class TestFlowExecutor(TestCase):
@ -45,8 +45,7 @@ class TestFlowExecutor(TestCase):
self.assertEqual(cancel_mock.call_count, 1) self.assertEqual(cancel_mock.call_count, 1)
@patch( @patch(
"passbook.flows.planner.FlowPlanner._check_flow_root_policies", "passbook.policies.engine.PolicyEngine.result", POLICY_RESULT_MOCK,
POLICY_RESULT_MOCK,
) )
def test_invalid_non_applicable_flow(self): def test_invalid_non_applicable_flow(self):
"""Tests that a non-applicable flow returns the correct error message""" """Tests that a non-applicable flow returns the correct error message"""

View File

@ -1,7 +1,7 @@
"""passbook multi-stage authentication engine""" """passbook multi-stage authentication engine"""
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from django.http import HttpRequest, HttpResponse from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, redirect, reverse from django.shortcuts import get_object_or_404, redirect, reverse
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.views.decorators.clickjacking import xframe_options_sameorigin from django.views.decorators.clickjacking import xframe_options_sameorigin
@ -164,7 +164,9 @@ class ToDefaultFlow(View):
designation: Optional[FlowDesignation] = None designation: Optional[FlowDesignation] = None
def dispatch(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
flow = get_object_or_404(Flow, designation=self.designation) flow = Flow.with_policy(request, designation=self.designation)
if not flow:
raise Http404
# If user already has a pending plan, clear it so we don't have to later. # If user already has a pending plan, clear it so we don't have to later.
if SESSION_KEY_PLAN in self.request.session: if SESSION_KEY_PLAN in self.request.session:
plan: FlowPlan = self.request.session[SESSION_KEY_PLAN] plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]

View File

@ -1,4 +1,6 @@
"""passbook policies app config""" """passbook policies app config"""
from importlib import import_module
from django.apps import AppConfig from django.apps import AppConfig
@ -8,3 +10,7 @@ class PassbookPoliciesConfig(AppConfig):
name = "passbook.policies" name = "passbook.policies"
label = "passbook_policies" label = "passbook_policies"
verbose_name = "passbook Policies" verbose_name = "passbook Policies"
def ready(self):
"""Load source_types from config file"""
import_module("passbook.policies.signals")

View File

@ -73,16 +73,20 @@ class PolicyEngine:
"""Build task group""" """Build task group"""
for binding in self._iter_bindings(): for binding in self._iter_bindings():
self._check_policy_type(binding.policy) self._check_policy_type(binding.policy)
policy = binding.policy key = cache_key(binding, self.request)
cached_policy = cache.get(cache_key(binding, self.request.user), None) cached_policy = cache.get(key, None)
if cached_policy and self.use_cache: if cached_policy and self.use_cache:
LOGGER.debug("P_ENG: Taking result from cache", policy=policy) LOGGER.debug(
"P_ENG: Taking result from cache",
policy=binding.policy,
cache_key=key,
)
self.__cached_policies.append(cached_policy) self.__cached_policies.append(cached_policy)
continue continue
LOGGER.debug("P_ENG: Evaluating policy", policy=policy) LOGGER.debug("P_ENG: Evaluating policy", policy=binding.policy)
our_end, task_end = Pipe(False) our_end, task_end = Pipe(False)
task = PolicyProcess(binding, self.request, task_end) task = PolicyProcess(binding, self.request, task_end)
LOGGER.debug("P_ENG: Starting Process", policy=policy) LOGGER.debug("P_ENG: Starting Process", policy=binding.policy)
task.start() task.start()
self.__processes.append( self.__processes.append(
PolicyProcessInfo(process=task, connection=our_end, binding=binding) PolicyProcessInfo(process=task, connection=our_end, binding=binding)
@ -103,7 +107,9 @@ class PolicyEngine:
x.result for x in self.__processes if x.result x.result for x in self.__processes if x.result
] ]
for result in process_results + self.__cached_policies: for result in process_results + self.__cached_policies:
LOGGER.debug("P_ENG: result", passing=result.passing) LOGGER.debug(
"P_ENG: result", passing=result.passing, messages=result.messages
)
if result.messages: if result.messages:
messages += result.messages messages += result.messages
if not result.passing: if not result.passing:

View File

@ -6,7 +6,6 @@ from typing import Optional
from django.core.cache import cache from django.core.cache import cache
from structlog import get_logger from structlog import get_logger
from passbook.core.models import User
from passbook.policies.exceptions import PolicyException from passbook.policies.exceptions import PolicyException
from passbook.policies.models import PolicyBinding from passbook.policies.models import PolicyBinding
from passbook.policies.types import PolicyRequest, PolicyResult from passbook.policies.types import PolicyRequest, PolicyResult
@ -14,11 +13,13 @@ from passbook.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
def cache_key(binding: PolicyBinding, user: Optional[User] = None) -> str: def cache_key(binding: PolicyBinding, request: PolicyRequest) -> str:
"""Generate Cache key for policy""" """Generate Cache key for policy"""
prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}" prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}"
if user: if request.http_request:
prefix += f"#{user.pk}" prefix += f"_{request.http_request.session.session_key}"
if request.user:
prefix += f"#{request.user.pk}"
return prefix return prefix
@ -65,7 +66,7 @@ class PolicyProcess(Process):
passing=policy_result.passing, passing=policy_result.passing,
user=self.request.user, user=self.request.user,
) )
key = cache_key(self.binding, self.request.user) key = cache_key(self.binding, self.request)
cache.set(key, policy_result) cache.set(key, policy_result)
LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key) LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key)
return policy_result return policy_result

View File

@ -0,0 +1,26 @@
"""passbook policy signals"""
from django.core.cache import cache
from django.db.models.signals import post_save
from django.dispatch import receiver
from structlog import get_logger
LOGGER = get_logger()
@receiver(post_save)
# pylint: disable=unused-argument
def invalidate_policy_cache(sender, instance, **_):
"""Invalidate Policy cache when policy is updated"""
from passbook.policies.models import Policy, PolicyBinding
if isinstance(instance, Policy):
LOGGER.debug("Invalidating policy cache", policy=instance)
total = 0
for binding in PolicyBinding.objects.filter(policy=instance):
prefix = (
f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}*"
)
keys = cache.keys(prefix)
total += len(keys)
cache.delete_many(keys)
LOGGER.debug("Deleted keys", len=total)

View File

@ -0,0 +1,63 @@
# Generated by Django 3.0.6 on 2020-05-23 19:32
from django.db import migrations
def create_default_property_mappings(apps, schema_editor):
"""Create default SAML Property Mappings"""
SAMLPropertyMapping = apps.get_model(
"passbook_providers_saml", "SAMLPropertyMapping"
)
db_alias = schema_editor.connection.alias
defaults = [
{
"FriendlyName": "eduPersonPrincipalName",
"Name": "urn:oid:1.3.6.1.4.1.5923.1.1.1.6",
"Expression": "{{ user.email }}",
},
{
"FriendlyName": "cn",
"Name": "urn:oid:2.5.4.3",
"Expression": "{{ user.name }}",
},
{
"FriendlyName": "mail",
"Name": "urn:oid:0.9.2342.19200300.100.1.3",
"Expression": "{{ user.email }}",
},
{
"FriendlyName": "displayName",
"Name": "urn:oid:2.16.840.1.113730.3.1.241",
"Expression": "{{ user.username }}",
},
{
"FriendlyName": "uid",
"Name": "urn:oid:0.9.2342.19200300.100.1.1",
"Expression": "{{ user.pk }}",
},
{
"FriendlyName": "member-of",
"Name": "member-of",
"Expression": "[{% for group in user.groups.all() %}'{{ group.name }}',{% endfor %}]",
},
]
for default in defaults:
SAMLPropertyMapping.objects.using(db_alias).get_or_create(
saml_name=default["Name"],
friendly_name=default["FriendlyName"],
expression=default["Expression"],
defaults={
"name": f"Autogenerated SAML Mapping: {default['FriendlyName']} -> {default['Expression']}"
},
)
class Migration(migrations.Migration):
dependencies = [
("passbook_providers_saml", "0001_initial"),
]
operations = [
migrations.RunPython(create_default_property_mappings),
]

View File

@ -23,6 +23,7 @@ class LDAPSourceSerializer(ModelSerializer):
"group_object_filter", "group_object_filter",
"user_group_membership_field", "user_group_membership_field",
"object_uniqueness_field", "object_uniqueness_field",
"sync_users",
"sync_groups", "sync_groups",
"sync_parent_group", "sync_parent_group",
"property_mappings", "property_mappings",

View File

@ -16,26 +16,10 @@ LOGGER = get_logger()
class Connector: class Connector:
"""Wrapper for ldap3 to easily manage user authentication and creation""" """Wrapper for ldap3 to easily manage user authentication and creation"""
_server: ldap3.Server
_connection = ldap3.Connection
_source: LDAPSource _source: LDAPSource
def __init__(self, source: LDAPSource): def __init__(self, source: LDAPSource):
self._source = source self._source = source
self._server = ldap3.Server(source.server_uri) # Implement URI parsing
def bind(self):
"""Bind using Source's Credentials"""
self._connection = ldap3.Connection(
self._server,
raise_exceptions=True,
user=self._source.bind_cn,
password=self._source.bind_password,
)
self._connection.bind()
if self._source.start_tls:
self._connection.start_tls()
@staticmethod @staticmethod
def encode_pass(password: str) -> bytes: def encode_pass(password: str) -> bytes:
@ -45,19 +29,23 @@ class Connector:
@property @property
def base_dn_users(self) -> str: def base_dn_users(self) -> str:
"""Shortcut to get full base_dn for user lookups""" """Shortcut to get full base_dn for user lookups"""
return ",".join([self._source.additional_user_dn, self._source.base_dn]) if self._source.additional_user_dn:
return f"{self._source.additional_user_dn},{self._source.base_dn}"
return self._source.base_dn
@property @property
def base_dn_groups(self) -> str: def base_dn_groups(self) -> str:
"""Shortcut to get full base_dn for group lookups""" """Shortcut to get full base_dn for group lookups"""
return ",".join([self._source.additional_group_dn, self._source.base_dn]) if self._source.additional_group_dn:
return f"{self._source.additional_group_dn},{self._source.base_dn}"
return self._source.base_dn
def sync_groups(self): def sync_groups(self):
"""Iterate over all LDAP Groups and create passbook_core.Group instances""" """Iterate over all LDAP Groups and create passbook_core.Group instances"""
if not self._source.sync_groups: if not self._source.sync_groups:
LOGGER.debug("Group syncing is disabled for this Source") LOGGER.warning("Group syncing is disabled for this Source")
return return
groups = self._connection.extend.standard.paged_search( groups = self._source.connection.extend.standard.paged_search(
search_base=self.base_dn_groups, search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter, search_filter=self._source.group_object_filter,
search_scope=ldap3.SUBTREE, search_scope=ldap3.SUBTREE,
@ -87,7 +75,10 @@ class Connector:
def sync_users(self): def sync_users(self):
"""Iterate over all LDAP Users and create passbook_core.User instances""" """Iterate over all LDAP Users and create passbook_core.User instances"""
users = self._connection.extend.standard.paged_search( if not self._source.sync_users:
LOGGER.warning("User syncing is disabled for this Source")
return
users = self._source.connection.extend.standard.paged_search(
search_base=self.base_dn_users, search_base=self.base_dn_users,
search_filter=self._source.user_object_filter, search_filter=self._source.user_object_filter,
search_scope=ldap3.SUBTREE, search_scope=ldap3.SUBTREE,
@ -101,9 +92,9 @@ class Connector:
LOGGER.warning("Cannot find uniqueness Field in attributes") LOGGER.warning("Cannot find uniqueness Field in attributes")
continue continue
try: try:
defaults = self._build_object_properties(attributes)
user, created = User.objects.update_or_create( user, created = User.objects.update_or_create(
attributes__ldap_uniq=uniq, attributes__ldap_uniq=uniq, defaults=defaults,
defaults=self._build_object_properties(attributes),
) )
except IntegrityError as exc: except IntegrityError as exc:
LOGGER.warning("Failed to create user", exc=exc) LOGGER.warning("Failed to create user", exc=exc)
@ -123,7 +114,7 @@ class Connector:
def sync_membership(self): def sync_membership(self):
"""Iterate over all Users and assign Groups using memberOf Field""" """Iterate over all Users and assign Groups using memberOf Field"""
users = self._connection.extend.standard.paged_search( users = self._source.connection.extend.standard.paged_search(
search_base=self.base_dn_users, search_base=self.base_dn_users,
search_filter=self._source.user_object_filter, search_filter=self._source.user_object_filter,
search_scope=ldap3.SUBTREE, search_scope=ldap3.SUBTREE,
@ -220,7 +211,7 @@ class Connector:
LOGGER.debug("Attempting Binding as user", user=user) LOGGER.debug("Attempting Binding as user", user=user)
try: try:
temp_connection = ldap3.Connection( temp_connection = ldap3.Connection(
self._server, self._source.connection.server,
user=user.attributes.get("distinguishedName"), user=user.attributes.get("distinguishedName"),
password=password, password=password,
raise_exceptions=True, raise_exceptions=True,

View File

@ -26,6 +26,7 @@ class LDAPSourceForm(forms.ModelForm):
"group_object_filter", "group_object_filter",
"user_group_membership_field", "user_group_membership_field",
"object_uniqueness_field", "object_uniqueness_field",
"sync_users",
"sync_groups", "sync_groups",
"sync_parent_group", "sync_parent_group",
"property_mappings", "property_mappings",

View File

@ -0,0 +1,18 @@
# Generated by Django 3.0.6 on 2020-05-23 19:17
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("passbook_sources_ldap", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="ldapsource",
name="sync_users",
field=models.BooleanField(default=True),
),
]

View File

@ -0,0 +1,35 @@
# Generated by Django 3.0.6 on 2020-05-23 19:30
from django.apps.registry import Apps
from django.db import migrations
def create_default_ad_property_mappings(apps: Apps, schema_editor):
LDAPPropertyMapping = apps.get_model("passbook_sources_ldap", "LDAPPropertyMapping")
mapping = {
"name": "{{ ldap.name }}",
"first_name": "{{ ldap.givenName }}",
"last_name": "{{ ldap.sn }}",
"username": "{{ ldap.sAMAccountName }}",
"email": "{{ ldap.mail }}",
}
db_alias = schema_editor.connection.alias
for object_field, expression in mapping.items():
LDAPPropertyMapping.objects.using(db_alias).get_or_create(
expression=expression,
object_field=object_field,
defaults={
"name": f"Autogenerated LDAP Mapping: {expression} -> {object_field}"
},
)
class Migration(migrations.Migration):
dependencies = [
("passbook_sources_ldap", "0002_ldapsource_sync_users"),
]
operations = [
migrations.RunPython(create_default_ad_property_mappings),
]

View File

@ -1,8 +1,10 @@
"""passbook LDAP Models""" """passbook LDAP Models"""
from typing import Optional
from django.core.validators import URLValidator from django.core.validators import URLValidator
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from ldap3 import Connection, Server
from passbook.core.models import Group, PropertyMapping, Source from passbook.core.models import Group, PropertyMapping, Source
@ -22,10 +24,12 @@ class LDAPSource(Source):
additional_user_dn = models.TextField( additional_user_dn = models.TextField(
help_text=_("Prepended to Base DN for User-queries."), help_text=_("Prepended to Base DN for User-queries."),
verbose_name=_("Addition User DN"), verbose_name=_("Addition User DN"),
blank=True,
) )
additional_group_dn = models.TextField( additional_group_dn = models.TextField(
help_text=_("Prepended to Base DN for Group-queries."), help_text=_("Prepended to Base DN for Group-queries."),
verbose_name=_("Addition Group DN"), verbose_name=_("Addition Group DN"),
blank=True,
) )
user_object_filter = models.TextField( user_object_filter = models.TextField(
@ -43,6 +47,7 @@ class LDAPSource(Source):
default="objectSid", help_text=_("Field which contains a unique Identifier.") default="objectSid", help_text=_("Field which contains a unique Identifier.")
) )
sync_users = models.BooleanField(default=True)
sync_groups = models.BooleanField(default=True) sync_groups = models.BooleanField(default=True)
sync_parent_group = models.ForeignKey( sync_parent_group = models.ForeignKey(
Group, blank=True, null=True, default=None, on_delete=models.SET_DEFAULT Group, blank=True, null=True, default=None, on_delete=models.SET_DEFAULT
@ -50,6 +55,25 @@ class LDAPSource(Source):
form = "passbook.sources.ldap.forms.LDAPSourceForm" form = "passbook.sources.ldap.forms.LDAPSourceForm"
_connection: Optional[Connection]
@property
def connection(self) -> Connection:
"""Get a fully connected and bound LDAP Connection"""
if not self._connection:
server = Server(self.server_uri)
self._connection = Connection(
server,
raise_exceptions=True,
user=self.bind_cn,
password=self.bind_password,
)
self._connection.bind()
if self.start_tls:
self._connection.start_tls()
return self._connection
class Meta: class Meta:
verbose_name = _("LDAP Source") verbose_name = _("LDAP Source")

View File

@ -9,7 +9,6 @@ def sync_groups(source_pk: int):
"""Sync LDAP Groups on background worker""" """Sync LDAP Groups on background worker"""
source = LDAPSource.objects.get(pk=source_pk) source = LDAPSource.objects.get(pk=source_pk)
connector = Connector(source) connector = Connector(source)
connector.bind()
connector.sync_groups() connector.sync_groups()
@ -18,7 +17,6 @@ def sync_users(source_pk: int):
"""Sync LDAP Users on background worker""" """Sync LDAP Users on background worker"""
source = LDAPSource.objects.get(pk=source_pk) source = LDAPSource.objects.get(pk=source_pk)
connector = Connector(source) connector = Connector(source)
connector.bind()
connector.sync_users() connector.sync_users()
@ -27,7 +25,6 @@ def sync():
"""Sync all sources""" """Sync all sources"""
for source in LDAPSource.objects.filter(enabled=True): for source in LDAPSource.objects.filter(enabled=True):
connector = Connector(source) connector = Connector(source)
connector.bind()
connector.sync_users() connector.sync_users()
connector.sync_groups() connector.sync_groups()
connector.sync_membership() connector.sync_membership()

View File

@ -0,0 +1,75 @@
"""LDAP Source tests"""
from unittest.mock import PropertyMock, patch
from django.test import TestCase
from ldap3 import MOCK_SYNC, OFFLINE_AD_2012_R2, Connection, Server
from passbook.core.models import User
from passbook.sources.ldap.connector import Connector
from passbook.sources.ldap.models import LDAPPropertyMapping, LDAPSource
def _build_mock_connection() -> Connection:
"""Create mock connection"""
server = Server("my_fake_server", get_info=OFFLINE_AD_2012_R2)
_pass = "foo" # noqa # nosec
connection = Connection(
server,
user="cn=my_user,ou=test,o=lab",
password=_pass,
client_strategy=MOCK_SYNC,
)
connection.strategy.add_entry(
"cn=user0,ou=test,o=lab",
{
"userPassword": "test0000",
"sAMAccountName": "user0_sn",
"revision": 0,
"objectSid": "unique-test0000",
"objectCategory": "Person",
},
)
connection.strategy.add_entry(
"cn=user1,ou=test,o=lab",
{
"userPassword": "test1111",
"sAMAccountName": "user1_sn",
"revision": 0,
"objectSid": "unique-test1111",
"objectCategory": "Person",
},
)
connection.strategy.add_entry(
"cn=user2,ou=test,o=lab",
{
"userPassword": "test2222",
"sAMAccountName": "user2_sn",
"revision": 0,
"objectSid": "unique-test2222",
"objectCategory": "Person",
},
)
connection.bind()
return connection
LDAP_CONNECTION_PATCH = PropertyMock(return_value=_build_mock_connection())
class LDAPSourceTests(TestCase):
"""LDAP Source tests"""
def setUp(self):
self.source = LDAPSource.objects.create(
name="ldap", slug="ldap", base_dn="o=lab"
)
self.source.property_mappings.set(LDAPPropertyMapping.objects.all())
self.source.save()
@patch("passbook.sources.ldap.models.LDAPSource.connection", LDAP_CONNECTION_PATCH)
def test_sync_users(self):
"""Test user sync"""
connector = Connector(self.source)
connector.sync_users()
user = User.objects.filter(username="user2_sn")
self.assertTrue(user.exists())

View File

@ -1,6 +1,6 @@
"""OAuth Clients""" """OAuth Clients"""
import json import json
from typing import Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import parse_qs, urlencode from urllib.parse import parse_qs, urlencode
from django.http import HttpRequest from django.http import HttpRequest
@ -14,24 +14,29 @@ from structlog import get_logger
from passbook import __version__ from passbook import __version__
LOGGER = get_logger() LOGGER = get_logger()
if TYPE_CHECKING:
from passbook.sources.oauth.models import OAuthSource
class BaseOAuthClient: class BaseOAuthClient:
"""Base OAuth Client""" """Base OAuth Client"""
session: Session session: Session
source: "OAuthSource"
def __init__(self, source, token=""): # nosec def __init__(self, source: "OAuthSource", token=""): # nosec
self.source = source self.source = source
self.token = token self.token = token
self.session = Session() self.session = Session()
self.session.headers.update({"User-Agent": "passbook %s" % __version__}) self.session.headers.update({"User-Agent": "passbook %s" % __version__})
def get_access_token(self, request, callback=None): def get_access_token(
self, request: HttpRequest, callback=None
) -> Optional[Dict[str, Any]]:
"Fetch access token from callback request." "Fetch access token from callback request."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_profile_info(self, token: Dict[str, str]): def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]:
"Fetch user profile information." "Fetch user profile information."
try: try:
headers = { headers = {
@ -45,7 +50,7 @@ class BaseOAuthClient:
LOGGER.warning("Unable to fetch user profile", exc=exc) LOGGER.warning("Unable to fetch user profile", exc=exc)
return None return None
else: else:
return response.json() or response.text return response.json()
def get_redirect_args(self, request, callback) -> Dict[str, str]: def get_redirect_args(self, request, callback) -> Dict[str, str]:
"Get request parameters for redirect url." "Get request parameters for redirect url."

View File

@ -21,7 +21,7 @@ from passbook.flows.planner import (
) )
from passbook.flows.views import SESSION_KEY_PLAN from passbook.flows.views import SESSION_KEY_PLAN
from passbook.lib.utils.urls import redirect_with_qs from passbook.lib.utils.urls import redirect_with_qs
from passbook.sources.oauth.clients import get_client from passbook.sources.oauth.clients import BaseOAuthClient, get_client
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND from passbook.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
@ -34,7 +34,7 @@ class OAuthClientMixin:
client_class: Optional[Callable] = None client_class: Optional[Callable] = None
def get_client(self, source): def get_client(self, source: OAuthSource) -> BaseOAuthClient:
"Get instance of the OAuth client for this source." "Get instance of the OAuth client for this source."
if self.client_class is not None: if self.client_class is not None:
# pylint: disable=not-callable # pylint: disable=not-callable

View File

@ -16,7 +16,7 @@ class IdentificationStageForm(forms.ModelForm):
class Meta: class Meta:
model = IdentificationStage model = IdentificationStage
fields = ["name", "user_fields", "template"] fields = ["name", "user_fields", "template", "enrollment_flow", "recovery_flow"]
widgets = { widgets = {
"name": forms.TextInput(), "name": forms.TextInput(),
} }

View File

@ -72,6 +72,7 @@ class TestUserWriteStage(TestCase):
plan.context[PLAN_CONTEXT_PROMPT] = { plan.context[PLAN_CONTEXT_PROMPT] = {
"username": "test-user-new", "username": "test-user-new",
"password": new_password, "password": new_password,
"some-custom-attribute": "test",
} }
session = self.client.session session = self.client.session
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
@ -88,6 +89,7 @@ class TestUserWriteStage(TestCase):
) )
self.assertTrue(user_qs.exists()) self.assertTrue(user_qs.exists())
self.assertTrue(user_qs.first().check_password(new_password)) self.assertTrue(user_qs.first().check_password(new_password))
self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test")
def test_without_data(self): def test_without_data(self):
"""Test without data results in error""" """Test without data results in error"""

View File

@ -5606,8 +5606,6 @@ definitions:
- bind_cn - bind_cn
- bind_password - bind_password
- base_dn - base_dn
- additional_user_dn
- additional_group_dn
type: object type: object
properties: properties:
pk: pk:
@ -5654,12 +5652,10 @@ definitions:
title: Addition User DN title: Addition User DN
description: Prepended to Base DN for User-queries. description: Prepended to Base DN for User-queries.
type: string type: string
minLength: 1
additional_group_dn: additional_group_dn:
title: Addition Group DN title: Addition Group DN
description: Prepended to Base DN for Group-queries. description: Prepended to Base DN for Group-queries.
type: string type: string
minLength: 1
user_object_filter: user_object_filter:
title: User object filter title: User object filter
description: Consider Objects matching this filter to be Users. description: Consider Objects matching this filter to be Users.
@ -5680,6 +5676,9 @@ definitions:
description: Field which contains a unique Identifier. description: Field which contains a unique Identifier.
type: string type: string
minLength: 1 minLength: 1
sync_users:
title: Sync users
type: boolean
sync_groups: sync_groups:
title: Sync groups title: Sync groups
type: boolean type: boolean