root: use channel send workaround for sync sending of websocket messages
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
7f009f6d02
commit
bff34cc5dc
|
@ -7,7 +7,6 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from channels.layers import get_channel_layer
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||||
from django.db.models.base import Model
|
from django.db.models.base import Model
|
||||||
|
@ -43,6 +42,7 @@ from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesContro
|
||||||
from authentik.providers.proxy.controllers.docker import ProxyDockerController
|
from authentik.providers.proxy.controllers.docker import ProxyDockerController
|
||||||
from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController
|
from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController
|
||||||
from authentik.root.celery import CELERY_APP
|
from authentik.root.celery import CELERY_APP
|
||||||
|
from authentik.root.messages.storage import closing_send
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s"
|
CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s"
|
||||||
|
@ -217,26 +217,23 @@ def outpost_post_save(model_class: str, model_pk: Any):
|
||||||
def outpost_send_update(model_instace: Model):
|
def outpost_send_update(model_instace: Model):
|
||||||
"""Send outpost update to all registered outposts, regardless to which authentik
|
"""Send outpost update to all registered outposts, regardless to which authentik
|
||||||
instance they are connected"""
|
instance they are connected"""
|
||||||
channel_layer = get_channel_layer()
|
|
||||||
if isinstance(model_instace, OutpostModel):
|
if isinstance(model_instace, OutpostModel):
|
||||||
for outpost in model_instace.outpost_set.all():
|
for outpost in model_instace.outpost_set.all():
|
||||||
_outpost_single_update(outpost, channel_layer)
|
_outpost_single_update(outpost)
|
||||||
elif isinstance(model_instace, Outpost):
|
elif isinstance(model_instace, Outpost):
|
||||||
_outpost_single_update(model_instace, channel_layer)
|
_outpost_single_update(model_instace)
|
||||||
|
|
||||||
|
|
||||||
def _outpost_single_update(outpost: Outpost, layer=None):
|
def _outpost_single_update(outpost: Outpost):
|
||||||
"""Update outpost instances connected to a single outpost"""
|
"""Update outpost instances connected to a single outpost"""
|
||||||
# Ensure token again, because this function is called when anything related to an
|
# Ensure token again, because this function is called when anything related to an
|
||||||
# OutpostModel is saved, so we can be sure permissions are right
|
# OutpostModel is saved, so we can be sure permissions are right
|
||||||
_ = outpost.token
|
_ = outpost.token
|
||||||
outpost.build_user_permissions(outpost.user)
|
outpost.build_user_permissions(outpost.user)
|
||||||
if not layer: # pragma: no cover
|
|
||||||
layer = get_channel_layer()
|
|
||||||
for state in OutpostState.for_outpost(outpost):
|
for state in OutpostState.for_outpost(outpost):
|
||||||
for channel in state.channel_ids:
|
for channel in state.channel_ids:
|
||||||
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
|
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
|
||||||
async_to_sync(layer.send)(channel, {"type": "event.update"})
|
async_to_sync(closing_send)(channel, {"type": "event.update"})
|
||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task()
|
@CELERY_APP.task()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Channels Messages storage"""
|
"""Channels Messages storage"""
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from channels.layers import get_channel_layer
|
from channels import DEFAULT_CHANNEL_LAYER
|
||||||
|
from channels.layers import channel_layers
|
||||||
from django.contrib.messages.storage.base import Message
|
from django.contrib.messages.storage.base import Message
|
||||||
from django.contrib.messages.storage.session import SessionStorage
|
from django.contrib.messages.storage.session import SessionStorage
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
@ -10,13 +11,21 @@ SESSION_KEY = "_messages"
|
||||||
CACHE_PREFIX = "goauthentik.io/root/messages_"
|
CACHE_PREFIX = "goauthentik.io/root/messages_"
|
||||||
|
|
||||||
|
|
||||||
|
async def closing_send(channel, message):
|
||||||
|
"""Wrapper around layer send that closes the connection"""
|
||||||
|
# See https://github.com/django/channels_redis/issues/332
|
||||||
|
# TODO: Remove this after channels_redis 4.1 is released
|
||||||
|
channel_layer = channel_layers.make_backend(DEFAULT_CHANNEL_LAYER)
|
||||||
|
await channel_layer.send(channel, message)
|
||||||
|
await channel_layer.close_pools()
|
||||||
|
|
||||||
|
|
||||||
class ChannelsStorage(SessionStorage):
|
class ChannelsStorage(SessionStorage):
|
||||||
"""Send contrib.messages over websocket"""
|
"""Send contrib.messages over websocket"""
|
||||||
|
|
||||||
def __init__(self, request: HttpRequest) -> None:
|
def __init__(self, request: HttpRequest) -> None:
|
||||||
# pyright: reportGeneralTypeIssues=false
|
# pyright: reportGeneralTypeIssues=false
|
||||||
super().__init__(request)
|
super().__init__(request)
|
||||||
self.channel = get_channel_layer()
|
|
||||||
|
|
||||||
def _store(self, messages: list[Message], response, *args, **kwargs):
|
def _store(self, messages: list[Message], response, *args, **kwargs):
|
||||||
prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_"
|
prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_"
|
||||||
|
@ -28,7 +37,7 @@ class ChannelsStorage(SessionStorage):
|
||||||
for key in keys:
|
for key in keys:
|
||||||
uid = key.replace(prefix, "")
|
uid = key.replace(prefix, "")
|
||||||
for message in messages:
|
for message in messages:
|
||||||
async_to_sync(self.channel.send)(
|
async_to_sync(closing_send)(
|
||||||
uid,
|
uid,
|
||||||
{
|
{
|
||||||
"type": "event.update",
|
"type": "event.update",
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
ignore = [
|
ignore = ["**/migrations/**", "**/node_modules/**"]
|
||||||
"**/migrations/**",
|
|
||||||
"**/node_modules/**"
|
|
||||||
]
|
|
||||||
reportMissingTypeStubs = false
|
reportMissingTypeStubs = false
|
||||||
strictParameterNoneValue = true
|
strictParameterNoneValue = true
|
||||||
strictDictionaryInference = true
|
strictDictionaryInference = true
|
||||||
|
@ -63,14 +60,7 @@ exclude_lines = [
|
||||||
show_missing = true
|
show_missing = true
|
||||||
|
|
||||||
[tool.pylint.basic]
|
[tool.pylint.basic]
|
||||||
good-names = [
|
good-names = ["pk", "id", "i", "j", "k", "_"]
|
||||||
"pk",
|
|
||||||
"id",
|
|
||||||
"i",
|
|
||||||
"j",
|
|
||||||
"k",
|
|
||||||
"_",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.pylint.master]
|
[tool.pylint.master]
|
||||||
disable = [
|
disable = [
|
||||||
|
@ -85,6 +75,7 @@ disable = [
|
||||||
"protected-access",
|
"protected-access",
|
||||||
"unused-argument",
|
"unused-argument",
|
||||||
"raise-missing-from",
|
"raise-missing-from",
|
||||||
|
"fixme",
|
||||||
# To preserve django's translation function we need to use %-formatting
|
# To preserve django's translation function we need to use %-formatting
|
||||||
"consider-using-f-string",
|
"consider-using-f-string",
|
||||||
]
|
]
|
||||||
|
@ -120,7 +111,7 @@ authors = ["authentik Team <hello@goauthentik.io>"]
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
celery = "*"
|
celery = "*"
|
||||||
channels = {version = "*", extras = ["daphne"]}
|
channels = { version = "*", extras = ["daphne"] }
|
||||||
channels-redis = "*"
|
channels-redis = "*"
|
||||||
codespell = "*"
|
codespell = "*"
|
||||||
colorama = "*"
|
colorama = "*"
|
||||||
|
@ -147,7 +138,7 @@ gunicorn = "*"
|
||||||
kubernetes = "*"
|
kubernetes = "*"
|
||||||
ldap3 = "*"
|
ldap3 = "*"
|
||||||
lxml = "*"
|
lxml = "*"
|
||||||
opencontainers = {extras = ["reggie"],version = "*"}
|
opencontainers = { extras = ["reggie"], version = "*" }
|
||||||
packaging = "*"
|
packaging = "*"
|
||||||
paramiko = "*"
|
paramiko = "*"
|
||||||
psycopg2-binary = "*"
|
psycopg2-binary = "*"
|
||||||
|
@ -163,8 +154,8 @@ swagger-spec-validator = "*"
|
||||||
twilio = "*"
|
twilio = "*"
|
||||||
twisted = "*"
|
twisted = "*"
|
||||||
ua-parser = "*"
|
ua-parser = "*"
|
||||||
urllib3 = {extras = ["secure"],version = "*"}
|
urllib3 = { extras = ["secure"], version = "*" }
|
||||||
uvicorn = {extras = ["standard"],version = "*"}
|
uvicorn = { extras = ["standard"], version = "*" }
|
||||||
webauthn = "*"
|
webauthn = "*"
|
||||||
wsproto = "*"
|
wsproto = "*"
|
||||||
xmlsec = "*"
|
xmlsec = "*"
|
||||||
|
@ -176,7 +167,7 @@ bandit = "*"
|
||||||
black = "*"
|
black = "*"
|
||||||
bump2version = "*"
|
bump2version = "*"
|
||||||
colorama = "*"
|
colorama = "*"
|
||||||
coverage = {extras = ["toml"],version = "*"}
|
coverage = { extras = ["toml"], version = "*" }
|
||||||
importlib-metadata = "*"
|
importlib-metadata = "*"
|
||||||
pylint = "*"
|
pylint = "*"
|
||||||
pylint-django = "*"
|
pylint-django = "*"
|
||||||
|
|
Reference in New Issue