diff --git a/authentik/sources/oauth/api/source.py b/authentik/sources/oauth/api/source.py index 0d4b65029..13db82c1e 100644 --- a/authentik/sources/oauth/api/source.py +++ b/authentik/sources/oauth/api/source.py @@ -5,6 +5,7 @@ from rest_framework.decorators import action from rest_framework.fields import BooleanField, CharField, SerializerMethodField from rest_framework.request import Request from rest_framework.response import Response +from rest_framework.serializers import ValidationError from rest_framework.viewsets import ModelViewSet from authentik.core.api.sources import SourceSerializer @@ -47,6 +48,20 @@ class OAuthSourceSerializer(SourceSerializer): """Get source's type configuration""" return SourceTypeSerializer(instace.type).data + def validate(self, attrs: dict) -> dict: + provider_type = MANAGER.find_type(attrs.get("provider_type", "")) + for url in [ + "authorization_url", + "access_token_url", + "profile_url", + ]: + if getattr(provider_type, url, None) is None: + if url not in attrs: + raise ValidationError( + f"{url} is required for provider {provider_type.name}" + ) + return attrs + class Meta: model = OAuthSource fields = SourceSerializer.Meta.fields + [ diff --git a/authentik/sources/oauth/migrations/0003_auto_20210416_0726.py b/authentik/sources/oauth/migrations/0003_auto_20210416_0726.py new file mode 100644 index 000000000..05ad87d2b --- /dev/null +++ b/authentik/sources/oauth/migrations/0003_auto_20210416_0726.py @@ -0,0 +1,43 @@ +# Generated by Django 3.2 on 2021-04-16 07:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_sources_oauth", "0002_auto_20200520_1108"), + ] + + operations = [ + migrations.AlterField( + model_name="oauthsource", + name="access_token_url", + field=models.CharField( + blank=True, + help_text="URL used by authentik to retrive tokens.", + max_length=255, + verbose_name="Access Token URL", + ), + ), + migrations.AlterField( + model_name="oauthsource", + name="authorization_url", + field=models.CharField( + blank=True, + help_text="URL the user is redirect to to conest the flow.", + max_length=255, + verbose_name="Authorization URL", + ), + ), + migrations.AlterField( + model_name="oauthsource", + name="profile_url", + field=models.CharField( + blank=True, + help_text="URL used by authentik to get user information.", + max_length=255, + verbose_name="Profile URL", + ), + ), + ] diff --git a/authentik/sources/oauth/models.py b/authentik/sources/oauth/models.py index 8fc39f4de..19466a675 100644 --- a/authentik/sources/oauth/models.py +++ b/authentik/sources/oauth/models.py @@ -28,16 +28,19 @@ class OAuthSource(Source): ) authorization_url = models.CharField( max_length=255, + blank=True, verbose_name=_("Authorization URL"), help_text=_("URL the user is redirect to to conest the flow."), ) access_token_url = models.CharField( max_length=255, + blank=True, verbose_name=_("Access Token URL"), help_text=_("URL used by authentik to retrive tokens."), ) profile_url = models.CharField( max_length=255, + blank=True, verbose_name=_("Profile URL"), help_text=_("URL used by authentik to get user information."), ) @@ -49,7 +52,7 @@ class OAuthSource(Source): """Return the provider instance for this source""" from authentik.sources.oauth.types.manager import MANAGER - return MANAGER.find_type(self) + return MANAGER.find_type(self.provider_type) @property def component(self) -> str: diff --git a/authentik/sources/oauth/tests/test_views.py b/authentik/sources/oauth/tests/test_views.py index 9478902ab..48a7ef188 100644 --- a/authentik/sources/oauth/tests/test_views.py +++ b/authentik/sources/oauth/tests/test_views.py @@ -1,4 +1,5 @@ """OAuth Source tests""" +from authentik.sources.oauth.api.source import OAuthSourceSerializer from django.test import TestCase from django.urls import reverse @@ -18,6 +19,23 @@ class TestOAuthSource(TestCase): consumer_key="", ) + def test_api_validate(self): + """Test API validation""" + self.assertTrue(OAuthSourceSerializer(data={ + "name": "foo", + "slug": "bar", + "provider_type": "google", + "consumer_key": "foo", + "consumer_secret": "foo", + }).is_valid()) + self.assertFalse(OAuthSourceSerializer(data={ + "name": "foo", + "slug": "bar", + "provider_type": "openid-connect", + "consumer_key": "foo", + "consumer_secret": "foo", + }).is_valid()) + def test_source_redirect(self): """test redirect view""" self.client.get( diff --git a/authentik/sources/oauth/types/manager.py b/authentik/sources/oauth/types/manager.py index b1b921912..d58cd21c6 100644 --- a/authentik/sources/oauth/types/manager.py +++ b/authentik/sources/oauth/types/manager.py @@ -58,17 +58,17 @@ class SourceTypeManager: """Get list of tuples of all registered names""" return [(x.slug, x.name) for x in self.__sources] - def find_type(self, source: "OAuthSource") -> SourceType: + def find_type(self, type_name: str) -> SourceType: """Find type based on source""" found_type = None for src_type in self.__sources: - if src_type.slug == source.provider_type: + if src_type.slug == type_name: return src_type if not found_type: found_type = SourceType() LOGGER.warning( "no matching type found, using default", - wanted=source.provider_type, + wanted=type_name, have=[x.name for x in self.__sources], ) return found_type diff --git a/swagger.yaml b/swagger.yaml index 23424d8cd..7425bdce5 100755 --- a/swagger.yaml +++ b/swagger.yaml @@ -16963,9 +16963,6 @@ definitions: - name - slug - provider_type - - authorization_url - - access_token_url - - profile_url - consumer_key - consumer_secret type: object @@ -17037,19 +17034,16 @@ definitions: description: URL the user is redirect to to conest the flow. type: string maxLength: 255 - minLength: 1 access_token_url: title: Access Token URL description: URL used by authentik to retrive tokens. type: string maxLength: 255 - minLength: 1 profile_url: title: Profile URL description: URL used by authentik to get user information. type: string maxLength: 255 - minLength: 1 consumer_key: title: Consumer key type: string