diff --git a/authentik/blueprints/management/commands/apply_blueprint.py b/authentik/blueprints/management/commands/apply_blueprint.py index 7b9081aae..8163b5844 100644 --- a/authentik/blueprints/management/commands/apply_blueprint.py +++ b/authentik/blueprints/management/commands/apply_blueprint.py @@ -18,7 +18,7 @@ class Command(BaseCommand): """Apply all blueprints in order, abort when one fails to import""" for blueprint_path in options.get("blueprints", []): content = BlueprintInstance(path=blueprint_path).retrieve() - importer = Importer(content) + importer = Importer(*content) valid, logs = importer.validate() if not valid: for log in logs: diff --git a/authentik/blueprints/models.py b/authentik/blueprints/models.py index 1a2171f75..a435a8fc7 100644 --- a/authentik/blueprints/models.py +++ b/authentik/blueprints/models.py @@ -70,7 +70,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): enabled = models.BooleanField(default=True) managed_models = ArrayField(models.TextField(), default=list) - def retrieve_oci(self) -> str: + def retrieve_oci(self) -> list[str]: """Get blueprint from an OCI registry""" client = BlueprintOCIClient(self.path.replace("oci://", "https://")) try: @@ -79,16 +79,16 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): except OCIException as exc: raise BlueprintRetrievalFailed(exc) from exc - def retrieve_file(self) -> str: + def retrieve_file(self) -> list[str]: """Get blueprint from path""" try: full_path = Path(CONFIG.y("blueprints_dir")).joinpath(Path(self.path)) with full_path.open("r", encoding="utf-8") as _file: - return _file.read() + return [_file.read()] except (IOError, OSError) as exc: raise BlueprintRetrievalFailed(exc) from exc - def retrieve(self) -> str: + def retrieve(self) -> list[str]: """Retrieve blueprint contents""" if self.path.startswith("oci://"): return self.retrieve_oci() diff --git a/authentik/blueprints/tests/__init__.py b/authentik/blueprints/tests/__init__.py index 4f62d668b..20eebacac 100644 --- a/authentik/blueprints/tests/__init__.py +++ b/authentik/blueprints/tests/__init__.py @@ -21,7 +21,7 @@ def apply_blueprint(*files: str): def wrapper(*args, **kwargs): for file in files: content = BlueprintInstance(path=file).retrieve() - Importer(content).apply() + Importer(*content).apply() return func(*args, **kwargs) return wrapper diff --git a/authentik/blueprints/tests/test_oci.py b/authentik/blueprints/tests/test_oci.py index dd54e2602..da47a2225 100644 --- a/authentik/blueprints/tests/test_oci.py +++ b/authentik/blueprints/tests/test_oci.py @@ -29,7 +29,7 @@ class TestBlueprintOCI(TransactionTestCase): BlueprintInstance( path="oci://ghcr.io/goauthentik/blueprints/test:latest" ).retrieve(), - "foo", + ["foo"], ) def test_manifests_error(self): diff --git a/authentik/blueprints/tests/test_packaged.py b/authentik/blueprints/tests/test_packaged.py index bb618fd75..0e17a56bf 100644 --- a/authentik/blueprints/tests/test_packaged.py +++ b/authentik/blueprints/tests/test_packaged.py @@ -25,7 +25,8 @@ def blueprint_tester(file_name: Path) -> Callable: def tester(self: TestPackaged): base = Path("blueprints/") rel_path = Path(file_name).relative_to(base) - importer = Importer(BlueprintInstance(path=str(rel_path)).retrieve()) + contents = BlueprintInstance(path=str(rel_path)).retrieve() + importer = Importer(*contents) self.assertTrue(importer.validate()[0]) self.assertTrue(importer.apply()) diff --git a/authentik/blueprints/v1/oci.py b/authentik/blueprints/v1/oci.py index b05d3c52f..064813699 100644 --- a/authentik/blueprints/v1/oci.py +++ b/authentik/blueprints/v1/oci.py @@ -75,22 +75,29 @@ class BlueprintOCIClient: raise OCIException(manifest["errors"]) return manifest - def fetch_blobs(self, manifest: dict[str, Any]): + def fetch_blobs(self, manifest: dict[str, Any]) -> list[str]: """Fetch blob based on manifest info""" - blob = None + blob_digests = [] for layer in manifest.get("layers", []): if layer.get("mediaType", "") == OCI_MEDIA_TYPE: - blob = layer.get("digest") - self.logger.debug("Found layer with matching media type", blob=blob) - if not blob: + blob_digests.append(layer.get("digest")) + if not blob_digests: raise OCIException("Blob not found") + bodies = [] + for blob in blob_digests: + bodies.append(self.fetch_blob(blob)) + self.logger.debug("Fetched blobs", count=len(bodies)) + return bodies + def fetch_blob(self, digest: str) -> str: + """Fetch blob based on manifest info""" blob_request = self.client.NewRequest( "GET", "/v2//blobs/", - WithDigest(blob), + WithDigest(digest), ) try: + self.logger.debug("Fetching blob", digest=digest) blob_response = self.client.Do(blob_request) blob_response.raise_for_status() return blob_response.text diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index 792811b14..bc3d585c3 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -185,8 +185,8 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str): if not instance or not instance.enabled: return blueprint_content = instance.retrieve() - file_hash = sha512(blueprint_content.encode()).hexdigest() - importer = Importer(blueprint_content, context=instance.context) + file_hash = sha512("".join(blueprint_content).encode()).hexdigest() + importer = Importer(*blueprint_content, context=instance.context) instance.metadata = importer.metadata valid, logs = importer.validate() if not valid: