From edc7f2fdb01f8660bf303e116fbc89e83cd05820 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Tue, 1 Aug 2023 20:21:19 +0200 Subject: [PATCH] separate blueprint importer from yaml parsing Signed-off-by: Jens Langhammer --- authentik/blueprints/api.py | 4 +- .../management/commands/apply_blueprint.py | 4 +- authentik/blueprints/tests/__init__.py | 4 +- authentik/blueprints/tests/test_packaged.py | 4 +- authentik/blueprints/tests/test_v1.py | 22 ++++---- .../tests/test_v1_conditional_fields.py | 4 +- .../blueprints/tests/test_v1_conditions.py | 6 +- authentik/blueprints/tests/test_v1_state.py | 14 ++--- authentik/blueprints/v1/importer.py | 55 +++++++++++-------- authentik/blueprints/v1/tasks.py | 4 +- authentik/flows/api/flows.py | 4 +- 11 files changed, 67 insertions(+), 58 deletions(-) diff --git a/authentik/blueprints/api.py b/authentik/blueprints/api.py index 4ae847106..d13a30812 100644 --- a/authentik/blueprints/api.py +++ b/authentik/blueprints/api.py @@ -12,7 +12,7 @@ from rest_framework.viewsets import ModelViewSet from authentik.api.decorators import permission_required from authentik.blueprints.models import BlueprintInstance -from authentik.blueprints.v1.importer import Importer +from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.oci import OCI_PREFIX from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict from authentik.core.api.used_by import UsedByMixin @@ -49,7 +49,7 @@ class BlueprintInstanceSerializer(ModelSerializer): if content == "": return content context = self.instance.context if self.instance else {} - valid, logs = Importer(content, context).validate() + valid, logs = StringImporter(content, context).validate() if not valid: text_logs = "\n".join([x["event"] for x in logs]) raise ValidationError(_("Failed to validate blueprint: %(logs)s" % {"logs": text_logs})) diff --git a/authentik/blueprints/management/commands/apply_blueprint.py b/authentik/blueprints/management/commands/apply_blueprint.py index 4aea0159d..427c13bb6 100644 --- a/authentik/blueprints/management/commands/apply_blueprint.py +++ b/authentik/blueprints/management/commands/apply_blueprint.py @@ -5,7 +5,7 @@ from django.core.management.base import BaseCommand, no_translations from structlog.stdlib import get_logger from authentik.blueprints.models import BlueprintInstance -from authentik.blueprints.v1.importer import Importer +from authentik.blueprints.v1.importer import StringImporter LOGGER = get_logger() @@ -18,7 +18,7 @@ class Command(BaseCommand): """Apply all blueprints in order, abort when one fails to import""" for blueprint_path in options.get("blueprints", []): content = BlueprintInstance(path=blueprint_path).retrieve() - importer = Importer(content) + importer = StringImporter(content) valid, _ = importer.validate() if not valid: self.stderr.write("blueprint invalid") diff --git a/authentik/blueprints/tests/__init__.py b/authentik/blueprints/tests/__init__.py index 8b39ca6dd..06a3d04bf 100644 --- a/authentik/blueprints/tests/__init__.py +++ b/authentik/blueprints/tests/__init__.py @@ -11,7 +11,7 @@ from authentik.blueprints.models import BlueprintInstance def apply_blueprint(*files: str): """Apply blueprint before test""" - from authentik.blueprints.v1.importer import Importer + from authentik.blueprints.v1.importer import StringImporter def wrapper_outer(func: Callable): """Apply blueprint before test""" @@ -20,7 +20,7 @@ def apply_blueprint(*files: str): def wrapper(*args, **kwargs): for file in files: content = BlueprintInstance(path=file).retrieve() - Importer(content).apply() + StringImporter(content).apply() return func(*args, **kwargs) return wrapper diff --git a/authentik/blueprints/tests/test_packaged.py b/authentik/blueprints/tests/test_packaged.py index bb618fd75..5edf7afc3 100644 --- a/authentik/blueprints/tests/test_packaged.py +++ b/authentik/blueprints/tests/test_packaged.py @@ -6,7 +6,7 @@ from django.test import TransactionTestCase from authentik.blueprints.models import BlueprintInstance from authentik.blueprints.tests import apply_blueprint -from authentik.blueprints.v1.importer import Importer +from authentik.blueprints.v1.importer import StringImporter from authentik.tenants.models import Tenant @@ -25,7 +25,7 @@ def blueprint_tester(file_name: Path) -> Callable: def tester(self: TestPackaged): base = Path("blueprints/") rel_path = Path(file_name).relative_to(base) - importer = Importer(BlueprintInstance(path=str(rel_path)).retrieve()) + importer = StringImporter(BlueprintInstance(path=str(rel_path)).retrieve()) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) diff --git a/authentik/blueprints/tests/test_v1.py b/authentik/blueprints/tests/test_v1.py index c5136a1ba..ae10bb66a 100644 --- a/authentik/blueprints/tests/test_v1.py +++ b/authentik/blueprints/tests/test_v1.py @@ -4,7 +4,7 @@ from os import environ from django.test import TransactionTestCase from authentik.blueprints.v1.exporter import FlowExporter -from authentik.blueprints.v1.importer import Importer, transaction_rollback +from authentik.blueprints.v1.importer import StringImporter, transaction_rollback from authentik.core.models import Group from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding from authentik.lib.generators import generate_id @@ -21,14 +21,14 @@ class TestBlueprintsV1(TransactionTestCase): def test_blueprint_invalid_format(self): """Test blueprint with invalid format""" - importer = Importer('{"version": 3}') + importer = StringImporter('{"version": 3}') self.assertFalse(importer.validate()[0]) - importer = Importer( + importer = StringImporter( '{"version": 1,"entries":[{"identifiers":{},"attrs":{},' '"model": "authentik_core.User"}]}' ) self.assertFalse(importer.validate()[0]) - importer = Importer( + importer = StringImporter( '{"version": 1, "entries": [{"attrs": {"name": "test"}, ' '"identifiers": {}, ' '"model": "authentik_core.Group"}]}' @@ -54,7 +54,7 @@ class TestBlueprintsV1(TransactionTestCase): }, ) - importer = Importer( + importer = StringImporter( '{"version": 1, "entries": [{"attrs": {"name": "test999", "attributes": ' '{"key": ["updated_value"]}}, "identifiers": {"attributes": {"other_key": ' '["other_value"]}}, "model": "authentik_core.Group"}]}' @@ -103,7 +103,7 @@ class TestBlueprintsV1(TransactionTestCase): self.assertEqual(len(export.entries), 3) export_yaml = exporter.export_to_string() - importer = Importer(export_yaml) + importer = StringImporter(export_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) @@ -113,14 +113,14 @@ class TestBlueprintsV1(TransactionTestCase): """Test export and import it twice""" count_initial = Prompt.objects.filter(field_key="username").count() - importer = Importer(load_fixture("fixtures/static_prompt_export.yaml")) + importer = StringImporter(load_fixture("fixtures/static_prompt_export.yaml")) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) count_before = Prompt.objects.filter(field_key="username").count() self.assertEqual(count_initial + 1, count_before) - importer = Importer(load_fixture("fixtures/static_prompt_export.yaml")) + importer = StringImporter(load_fixture("fixtures/static_prompt_export.yaml")) self.assertTrue(importer.apply()) self.assertEqual(Prompt.objects.filter(field_key="username").count(), count_before) @@ -130,7 +130,7 @@ class TestBlueprintsV1(TransactionTestCase): ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").delete() Group.objects.filter(name="test").delete() environ["foo"] = generate_id() - importer = Importer(load_fixture("fixtures/tags.yaml"), {"bar": "baz"}) + importer = StringImporter(load_fixture("fixtures/tags.yaml"), {"bar": "baz"}) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) policy = ExpressionPolicy.objects.filter(name="foo-bar-baz-qux").first() @@ -247,7 +247,7 @@ class TestBlueprintsV1(TransactionTestCase): exporter = FlowExporter(flow) export_yaml = exporter.export_to_string() - importer = Importer(export_yaml) + importer = StringImporter(export_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) self.assertTrue(UserLoginStage.objects.filter(name=stage_name).exists()) @@ -296,7 +296,7 @@ class TestBlueprintsV1(TransactionTestCase): exporter = FlowExporter(flow) export_yaml = exporter.export_to_string() - importer = Importer(export_yaml) + importer = StringImporter(export_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) diff --git a/authentik/blueprints/tests/test_v1_conditional_fields.py b/authentik/blueprints/tests/test_v1_conditional_fields.py index a28083651..a8b880157 100644 --- a/authentik/blueprints/tests/test_v1_conditional_fields.py +++ b/authentik/blueprints/tests/test_v1_conditional_fields.py @@ -1,7 +1,7 @@ """Test blueprints v1""" from django.test import TransactionTestCase -from authentik.blueprints.v1.importer import Importer +from authentik.blueprints.v1.importer import StringImporter from authentik.core.models import Application, Token, User from authentik.core.tests.utils import create_test_admin_user from authentik.flows.models import Flow @@ -18,7 +18,7 @@ class TestBlueprintsV1ConditionalFields(TransactionTestCase): self.uid = generate_id() import_yaml = load_fixture("fixtures/conditional_fields.yaml", uid=self.uid, user=user.pk) - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) diff --git a/authentik/blueprints/tests/test_v1_conditions.py b/authentik/blueprints/tests/test_v1_conditions.py index 3914e19dd..57217c49e 100644 --- a/authentik/blueprints/tests/test_v1_conditions.py +++ b/authentik/blueprints/tests/test_v1_conditions.py @@ -1,7 +1,7 @@ """Test blueprints v1""" from django.test import TransactionTestCase -from authentik.blueprints.v1.importer import Importer +from authentik.blueprints.v1.importer import StringImporter from authentik.flows.models import Flow from authentik.lib.generators import generate_id from authentik.lib.tests.utils import load_fixture @@ -18,7 +18,7 @@ class TestBlueprintsV1Conditions(TransactionTestCase): "fixtures/conditions_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2 ) - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) # Ensure objects exist @@ -35,7 +35,7 @@ class TestBlueprintsV1Conditions(TransactionTestCase): "fixtures/conditions_not_fulfilled.yaml", id1=flow_slug1, id2=flow_slug2 ) - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) # Ensure objects do not exist diff --git a/authentik/blueprints/tests/test_v1_state.py b/authentik/blueprints/tests/test_v1_state.py index e0bf4c13a..a14b7f5e8 100644 --- a/authentik/blueprints/tests/test_v1_state.py +++ b/authentik/blueprints/tests/test_v1_state.py @@ -1,7 +1,7 @@ """Test blueprints v1""" from django.test import TransactionTestCase -from authentik.blueprints.v1.importer import Importer +from authentik.blueprints.v1.importer import StringImporter from authentik.flows.models import Flow from authentik.lib.generators import generate_id from authentik.lib.tests.utils import load_fixture @@ -15,7 +15,7 @@ class TestBlueprintsV1State(TransactionTestCase): flow_slug = generate_id() import_yaml = load_fixture("fixtures/state_present.yaml", id=flow_slug) - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) # Ensure object exists @@ -30,7 +30,7 @@ class TestBlueprintsV1State(TransactionTestCase): self.assertEqual(flow.title, "bar") # Ensure importer updates it - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) flow: Flow = Flow.objects.filter(slug=flow_slug).first() @@ -41,7 +41,7 @@ class TestBlueprintsV1State(TransactionTestCase): flow_slug = generate_id() import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug) - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) # Ensure object exists @@ -56,7 +56,7 @@ class TestBlueprintsV1State(TransactionTestCase): self.assertEqual(flow.title, "bar") # Ensure importer doesn't update it - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) flow: Flow = Flow.objects.filter(slug=flow_slug).first() @@ -67,7 +67,7 @@ class TestBlueprintsV1State(TransactionTestCase): flow_slug = generate_id() import_yaml = load_fixture("fixtures/state_created.yaml", id=flow_slug) - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) # Ensure object exists @@ -75,7 +75,7 @@ class TestBlueprintsV1State(TransactionTestCase): self.assertEqual(flow.slug, flow_slug) import_yaml = load_fixture("fixtures/state_absent.yaml", id=flow_slug) - importer = Importer(import_yaml) + importer = StringImporter(import_yaml) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) flow: Flow = Flow.objects.filter(slug=flow_slug).first() diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index e9b29938e..c528d6fd1 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -85,27 +85,22 @@ class Importer: """Import Blueprint from YAML""" logger: BoundLogger + _import: Blueprint - def __init__(self, yaml_input: str, context: Optional[dict] = None): + def __init__(self, blueprint: Blueprint, context: Optional[dict] = None): self.__pk_map: dict[Any, Model] = {} + self._import = blueprint self.logger = get_logger() - import_dict = load(yaml_input, BlueprintLoader) - try: - self.__import = from_dict( - Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState]) - ) - except DaciteError as exc: - raise EntryInvalidError from exc ctx = {} - always_merger.merge(ctx, self.__import.context) + always_merger.merge(ctx, self._import.context) if context: always_merger.merge(ctx, context) - self.__import.context = ctx + self._import.context = ctx @property def blueprint(self) -> Blueprint: """Get imported blueprint""" - return self.__import + return self._import def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]: """Replace any value if it is a known primary key of an other object""" @@ -151,11 +146,11 @@ class Importer: # pylint: disable-msg=too-many-locals def _validate_single(self, entry: BlueprintEntry) -> Optional[BaseSerializer]: """Validate a single entry""" - if not entry.check_all_conditions_match(self.__import): + if not entry.check_all_conditions_match(self._import): self.logger.debug("One or more conditions of this entry are not fulfilled, skipping") return None - model_app_label, model_name = entry.get_model(self.__import).split(".") + model_app_label, model_name = entry.get_model(self._import).split(".") model: type[SerializerModel] = registry.get_model(model_app_label, model_name) # Don't use isinstance since we don't want to check for inheritance if not is_model_allowed(model): @@ -163,7 +158,7 @@ class Importer: if issubclass(model, BaseMetaModel): serializer_class: type[Serializer] = model.serializer() serializer = serializer_class( - data=entry.get_attrs(self.__import), + data=entry.get_attrs(self._import), context={ SERIALIZER_CONTEXT_BLUEPRINT: entry, }, @@ -181,7 +176,7 @@ class Importer: # the full serializer for later usage # Because a model might have multiple unique columns, we chain all identifiers together # to create an OR query. - updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self.__import)) + updated_identifiers = self.__update_pks_for_attrs(entry.get_identifiers(self._import)) for key, value in list(updated_identifiers.items()): if isinstance(value, dict) and "pk" in value: del updated_identifiers[key] @@ -217,7 +212,7 @@ class Importer: model_instance.pk = updated_identifiers["pk"] serializer_kwargs["instance"] = model_instance try: - full_data = self.__update_pks_for_attrs(entry.get_attrs(self.__import)) + full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import)) except ValueError as exc: raise EntryInvalidError(exc) from exc always_merger.merge(full_data, updated_identifiers) @@ -252,8 +247,8 @@ class Importer: def _apply_models(self) -> bool: """Apply (create/update) models yaml""" self.__pk_map = {} - for entry in self.__import.entries: - model_app_label, model_name = entry.get_model(self.__import).split(".") + for entry in self._import.entries: + model_app_label, model_name = entry.get_model(self._import).split(".") try: model: type[SerializerModel] = registry.get_model(model_app_label, model_name) except LookupError: @@ -266,14 +261,14 @@ class Importer: serializer = self._validate_single(entry) except EntryInvalidError as exc: # For deleting objects we don't need the serializer to be valid - if entry.get_state(self.__import) == BlueprintEntryDesiredState.ABSENT: + if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT: continue self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) return False if not serializer: continue - state = entry.get_state(self.__import) + state = entry.get_state(self._import) if state in [BlueprintEntryDesiredState.PRESENT, BlueprintEntryDesiredState.CREATED]: instance = serializer.instance if ( @@ -306,8 +301,8 @@ class Importer: """Validate loaded blueprint export, ensure all models are allowed and serializers have no errors""" self.logger.debug("Starting blueprint import validation") - orig_import = deepcopy(self.__import) - if self.__import.version != 1: + orig_import = deepcopy(self._import) + if self._import.version != 1: self.logger.warning("Invalid blueprint version") return False, [{"event": "Invalid blueprint version"}] with ( @@ -320,5 +315,19 @@ class Importer: for log in logs: getattr(self.logger, log.get("log_level"))(**log) self.logger.debug("Finished blueprint import validation") - self.__import = orig_import + self._import = orig_import return successful, logs + + +class StringImporter(Importer): + """Importer that also parses from string""" + + def __init__(self, yaml_input: str, context: dict | None = None): + import_dict = load(yaml_input, BlueprintLoader) + try: + _import = from_dict( + Blueprint, import_dict, config=Config(cast=[BlueprintEntryDesiredState]) + ) + except DaciteError as exc: + raise EntryInvalidError from exc + super().__init__(_import, context) diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index b63c0b144..40351dfc1 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -26,7 +26,7 @@ from authentik.blueprints.models import ( BlueprintRetrievalFailed, ) from authentik.blueprints.v1.common import BlueprintLoader, BlueprintMetadata, EntryInvalidError -from authentik.blueprints.v1.importer import Importer +from authentik.blueprints.v1.importer import StringImporter from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE from authentik.blueprints.v1.oci import OCI_PREFIX from authentik.events.monitored_tasks import ( @@ -190,7 +190,7 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str): self.set_uid(slugify(instance.name)) blueprint_content = instance.retrieve() file_hash = sha512(blueprint_content.encode()).hexdigest() - importer = Importer(blueprint_content, instance.context) + importer = StringImporter(blueprint_content, instance.context) if importer.blueprint.metadata: instance.metadata = asdict(importer.blueprint.metadata) valid, logs = importer.validate() diff --git a/authentik/flows/api/flows.py b/authentik/flows/api/flows.py index 07bf9f009..a12e28929 100644 --- a/authentik/flows/api/flows.py +++ b/authentik/flows/api/flows.py @@ -16,7 +16,7 @@ from structlog.stdlib import get_logger from authentik.api.decorators import permission_required from authentik.blueprints.v1.exporter import FlowExporter -from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer +from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, StringImporter from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import CacheSerializer, LinkSerializer, PassiveSerializer from authentik.events.utils import sanitize_dict @@ -181,7 +181,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): if not file: return Response(data=import_response.initial_data, status=400) - importer = Importer(file.read().decode()) + importer = StringImporter(file.read().decode()) valid, logs = importer.validate() import_response.initial_data["logs"] = [sanitize_dict(log) for log in logs] import_response.initial_data["success"] = valid