get enterprise token

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-12-16 22:56:31 +01:00
parent 0d45b36cf2
commit fe05ea6048
No known key found for this signature in database
2 changed files with 26 additions and 4 deletions

View File

@ -88,7 +88,7 @@ class LicenseKey:
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
active_licenses = License.non_expired()
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
total.internal_users += lic.internal_users
@ -167,6 +167,10 @@ class License(SerializerModel):
internal_users = models.BigIntegerField()
external_users = models.BigIntegerField()
@classmethod
def non_expired(cls) -> QuerySet["License"]:
return License.objects.filter(expiry__gte=now())
@property
def serializer(self) -> type[BaseSerializer]:
from authentik.enterprise.api import LicenseSerializer

View File

@ -8,6 +8,8 @@ from grpc import (
UnaryUnaryClientInterceptor,
insecure_channel,
intercept_channel,
ssl_channel_credentials,
secure_channel,
)
from grpc._interceptor import _ClientCallDetails
@ -48,12 +50,28 @@ class AuthInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor)
return continuation(self._intercept_client_call_details(client_call_details), request)
@lru_cache()
def get_enterprise_token() -> str:
"""Get enterprise license key, if a license is installed, otherwise use the install ID"""
from authentik.root.install_id import get_install_id
try:
from authentik.enterprise.models import License
license = License.non_expired().order_by("-expiry").first()
if not license:
return get_install_id()
return license.key
except ImportError:
return get_install_id()
@lru_cache()
def get_client(addr: str):
"""get a cached client to a cloud-gateway"""
target = addr
channel = secure_channel(addr, ssl_channel_credentials)
if settings.DEBUG:
target = insecure_channel(target)
channel = intercept_channel(target, AuthInterceptor("foo"))
channel = insecure_channel(addr)
channel = intercept_channel(addr, AuthInterceptor(get_enterprise_token()))
client = AuthenticationPushStub(channel)
return client