from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.contrib.auth import get_user_model
from django.contrib.sessions.models import Session
# Resolve the user model once at import time via Django's get_user_model(),
# rather than importing a concrete User class directly.
User = get_user_model()
[docs]
class SessionAuthMiddleware(BaseMiddleware):
    """
    Custom middleware to authenticate WebSocket connections using Django session.

    Reads the session cookie (named by ``settings.SESSION_COOKIE_NAME``,
    ``sessionid`` by default) from the connection headers, looks the session
    up in the database-backed ``Session`` table and stores the result in
    ``scope['user']``.

    ``scope['user']`` is a user model instance when the cookie maps to a
    non-expired session that carries ``_auth_user_id`` for an existing user,
    and ``None`` in every other case.

    NOTE(review): unlike ``django.contrib.auth.get_user`` this does not
    verify the session auth hash or ``is_active``; confirm whether that is
    intended for WebSocket connections.
    """

    async def __call__(self, scope, receive, send):
        """
        Populate ``scope['user']`` from the session cookie, then delegate.

        Args:
            scope: ASGI connection scope; mutated in place to add ``'user'``.
            receive: ASGI receive callable, passed through unchanged.
            send: ASGI send callable, passed through unchanged.

        Returns:
            Whatever ``BaseMiddleware.__call__`` returns for the inner app.
        """
        # Imported locally so this class does not depend on an extra
        # module-level import being present in the file.
        from django.conf import settings

        # ASGI headers are (name, value) byte pairs with lower-cased names.
        headers = dict(scope.get('headers', []))
        # latin-1 maps every byte to a code point, so a malformed header from
        # an untrusted client cannot raise UnicodeDecodeError here. Session
        # keys are ASCII, so valid cookies decode exactly as before.
        cookie_header = headers.get(b'cookie', b'').decode('latin-1')

        session_id = self._get_cookie(cookie_header, settings.SESSION_COOKIE_NAME)

        # Authenticate user from session; an absent or empty cookie means
        # there is nothing to look up.
        if session_id:
            scope['user'] = await self.get_user_from_session(session_id)
        else:
            scope['user'] = None
        return await super().__call__(scope, receive, send)

    @staticmethod
    def _get_cookie(cookie_header, name):
        """Return the value of the first cookie called *name*, or ``None``."""
        prefix = name + '='
        for cookie in cookie_header.split(';'):
            cookie = cookie.strip()
            if cookie.startswith(prefix):
                # Everything after the first '=' is the value.
                return cookie[len(prefix):]
        return None

    @database_sync_to_async
    def get_user_from_session(self, session_id):
        """
        Get user from session ID.

        Args:
            session_id: Session key taken from the session cookie.

        Returns:
            The matching user instance, or ``None`` when the session does not
            exist, has expired, holds no ``_auth_user_id``, or refers to a
            user that no longer exists.
        """
        from django.utils import timezone

        try:
            # Expired rows remain in the table until they are purged, so the
            # expiry must be checked explicitly or stale sessions would still
            # authenticate.
            session = Session.objects.get(
                session_key=session_id,
                expire_date__gt=timezone.now(),
            )
        except Session.DoesNotExist:
            return None

        user_id = session.get_decoded().get('_auth_user_id')
        if not user_id:
            return None

        try:
            return User.objects.get(pk=user_id)
        except User.DoesNotExist:
            return None