Files
sirh/venv/lib/python3.12/site-packages/simple_sso/utils.py

159 lines
4.9 KiB
Python
Raw Normal View History

2026-04-29 11:52:03 +02:00
import string
from random import SystemRandom
from urllib.parse import urlparse, urlunparse, urljoin
import requests
from django.conf import settings
from django.http import HttpResponse
from django.views.decorators.csrf import csrf_exempt
from itsdangerous import TimedSerializer, SignatureExpired, BadSignature
from simple_sso.exceptions import BadRequest, WebserviceError
# Module-level RNG backed by the operating system's entropy source; in this
# file the name ``random`` refers to this instance, not the stdlib module.
random = SystemRandom()
# Alphabet used for generated keys: ASCII letters followed by decimal digits.
KEY_CHARACTERS = ''.join((string.ascii_letters, string.digits))
# Request header that carries the consumer's public key.
PUBLIC_KEY_HEADER = 'x-services-public-key'
def default_gen_secret_key(length=40):
    """Return a random key of ``length`` characters.

    Each character is drawn independently from ``KEY_CHARACTERS`` using the
    module-level ``SystemRandom`` instance bound to ``random``.
    """
    picks = (random.choice(KEY_CHARACTERS) for _ in range(length))
    return ''.join(picks)
def gen_secret_key(length=40):
    """Return a secret key of ``length`` characters.

    The generator is the ``SIMPLE_SSO_KEYGENERATOR`` Django setting when it
    is defined, otherwise ``default_gen_secret_key``. It is called with
    ``length`` and its result is returned unchanged.
    """
    try:
        key_generator = settings.SIMPLE_SSO_KEYGENERATOR
    except AttributeError:
        key_generator = default_gen_secret_key
    return key_generator(length)
def _split_dsn(dsn):
    """Split a DSN into ``(base_url, public_key, private_key)``.

    The DSN has the form ``scheme://public_key:private_key@host[:port]/path``.
    The credentials are read from the userinfo part and removed from the
    returned URL. Scheme, host, port, path, params, query and fragment are
    kept.

    Notes:
        ``urlparse`` lowercases the hostname and does not percent-decode the
        username/password, so the keys are returned exactly as written.

    Raises:
        ValueError: if the DSN contains an invalid or out-of-range port
            (raised by ``urlparse(...).port``).
    """
    parse_result = urlparse(dsn)
    # ``hostname`` is None when the DSN has no host part; use '' so that
    # appending a port below cannot raise TypeError.
    host = parse_result.hostname or ''
    # ``hostname`` strips the brackets from IPv6 literals; restore them so the
    # rebuilt netloc (and its optional ":port" suffix) remains parseable.
    if ':' in host:
        host = '[%s]' % host
    port = parse_result.port
    if port is not None:
        host += ':%s' % port
    base_url = urlunparse((
        parse_result.scheme,
        host,
        parse_result.path,
        parse_result.params,
        parse_result.query,
        parse_result.fragment,
    ))
    return base_url, parse_result.username, parse_result.password
class BaseConsumer(object):
    """Client side of a signed request/response exchange with a provider.

    Outgoing data is serialized and signed with ``TimedSerializer(private_key)``
    and sent together with the public-key header. The response body is
    verified and deserialized with the same serializer. Subclasses must
    implement :meth:`send_request` to perform the actual transport.
    """

    def __init__(self, base_url, public_key, private_key):
        self.base_url = base_url
        self.public_key = public_key
        self.signer = TimedSerializer(private_key)

    @classmethod
    def from_dsn(cls, dsn):
        """Build a consumer from ``scheme://public_key:private_key@host/path``."""
        base_url, public_key, private_key = _split_dsn(dsn)
        return cls(base_url, public_key, private_key)

    def consume(self, path, data, max_age=None):
        """Send signed ``data`` to ``path`` and return the verified response.

        :param path: endpoint path relative to ``base_url``; must start with '/'.
        :param data: payload to sign; it must be serializable by the signer
            (itsdangerous serializers default to JSON).
        :param max_age: forwarded to the signer's ``loads`` when verifying the
            response; itsdangerous skips the age check when it is None.
        :raises ValueError: if ``path`` does not start with a slash.
        """
        if not path.startswith('/'):
            raise ValueError("Paths must start with a slash")
        signed_data = self.signer.dumps(data)
        headers = {
            PUBLIC_KEY_HEADER: self.public_key,
            'Content-Type': 'application/json',
        }
        url = self.build_url(path)
        body = self.send_request(url, data=signed_data, headers=headers)
        return self.handle_response(body, max_age)

    def handle_response(self, body, max_age):
        """Verify the signature on ``body`` and return the deserialized data."""
        return self.signer.loads(body, max_age=max_age)

    def send_request(self, url, data, headers):
        """Transport hook: send ``data`` with ``headers`` to ``url``.

        Implementations return the raw response body, which
        :meth:`consume` passes to :meth:`handle_response`.
        """
        raise NotImplementedError(
            'Implement send_request on BaseConsumer subclasses')

    @staticmethod
    def raise_for_status(status_code, message):
        """Raise BadRequest for 400, WebserviceError for any other status >= 300."""
        if status_code == 400:
            raise BadRequest(message)
        elif status_code >= 300:
            raise WebserviceError(message)

    def build_url(self, path):
        """Return ``path`` resolved beneath ``base_url``.

        ``urljoin`` replaces the last segment of a base path that lacks a
        trailing slash (``http://h/server`` + ``x`` gives ``http://h/x``).
        The base path is therefore normalized to end with '/' first, so that
        paths resolve beneath it.
        """
        base = urlparse(self.base_url)
        if not base.path.endswith('/'):
            base = base._replace(path=base.path + '/')
        return urljoin(urlunparse(base), path.lstrip('/'))
class SyncConsumer(BaseConsumer):
    """Consumer that sends each request synchronously over a ``requests`` session.

    :param timeout: optional ``timeout`` passed to every ``session.post`` call.
        The default ``None`` keeps requests' behavior of waiting indefinitely;
        setting a finite value is recommended in production.
    """

    # Class-level default so subclasses that bypass __init__ still work.
    timeout = None

    def __init__(self, base_url, public_key, private_key, timeout=None):
        super().__init__(base_url, public_key, private_key)
        # ``requests.session()`` is a deprecated alias kept only for backwards
        # compatibility; ``requests.Session()`` is the supported constructor.
        self.session = requests.Session()
        self.timeout = timeout

    def send_request(self, url, data, headers):  # pragma: no cover
        """POST ``data`` to ``url`` and return the raw response body.

        Statuses of 300 and above are turned into exceptions by
        ``raise_for_status`` (BadRequest for 400, WebserviceError otherwise).
        """
        response = self.session.post(
            url, data=data, headers=headers, timeout=self.timeout)
        self.raise_for_status(response.status_code, response.content)
        return response.content
class BaseProvider(object):
    """Server side of the signed request/response exchange.

    Subclasses implement :meth:`get_private_key`, which maps a consumer's
    public key to its private key, and :meth:`provide`, which computes the
    response payload. :meth:`get_response` verifies the request and signs
    the reply.
    """

    # Maximum accepted age of an incoming signature, passed as ``max_age`` to
    # ``TimedSerializer.loads``; None disables the age check.
    max_age = None

    def provide(self, data):
        """Return the response payload for the verified request ``data``."""
        raise NotImplementedError(
            'Subclasses of services.models.Provider must implement '
            'the provide method'
        )

    def get_private_key(self, public_key):
        """Return the private key for ``public_key``, or a falsy value if unknown."""
        raise NotImplementedError(
            'Subclasses of services.models.Provider must implement '
            'the get_private_key method'
        )

    def report_exception(self):
        """Hook invoked when :meth:`provide` raises; the default does nothing.

        It is called from inside the ``except`` block, so ``sys.exc_info()``
        is available to overrides.
        """
        pass

    def get_response(self, method, signed_data, get_header):
        """Validate a signed request and return ``(status_code, body)``.

        :param method: HTTP method; anything other than 'POST' yields 405.
        :param signed_data: the signed request body.
        :param get_header: callable ``(name, default)`` returning a request header.
        :returns: ``(405, ['POST'])``, ``(400, message)`` on a missing or
            unknown public key, a bad or expired signature, or a failing
            :meth:`provide`; otherwise ``(200, signed_response)``.
        """
        if method != 'POST':
            return 405, ['POST']
        public_key = get_header(PUBLIC_KEY_HEADER, None)
        if not public_key:
            return 400, "No public key"
        private_key = self.get_private_key(public_key)
        if not private_key:
            return 400, "Invalid public key"
        signer = TimedSerializer(private_key)
        try:
            data = signer.loads(signed_data, max_age=self.max_age)
        except SignatureExpired:
            # Must precede BadSignature, of which SignatureExpired is a subclass.
            return 400, "Signature expired"
        except BadSignature:
            return 400, "Bad Signature"
        try:
            raw_response_data = self.provide(data)
        except Exception:
            # Catch application errors only: KeyboardInterrupt, SystemExit and
            # other BaseException-only signals must propagate, not become a 400.
            self.report_exception()
            return 400, "Failed to process the request"
        response_data = signer.dumps(raw_response_data)
        return 200, response_data
def provider_wrapper(provider):
    """Wrap ``provider`` in a CSRF-exempt Django view.

    The view passes the request method, raw body and a header accessor to
    ``provider.get_response``. It returns the resulting ``(status, body)``
    as an ``HttpResponse``.
    """
    def provider_view(request):
        def get_header(key, default):
            # Django stores request headers in META as HTTP_<NAME>, upper-cased
            # with dashes replaced by underscores.
            django_key = 'HTTP_%s' % key.upper().replace('-', '_')
            return request.META.get(django_key, default)

        # Fall back to the legacy ``raw_post_data`` attribute (older Django
        # versions) only when ``body`` does not exist. The previous truthiness
        # test also took this branch for an empty body (e.g. every GET). On
        # current Django that raised AttributeError instead of letting
        # get_response answer with 405/400.
        signed_data = getattr(request, 'body', None)
        if signed_data is None:
            signed_data = getattr(request, 'raw_post_data', b'')
        status_code, data = provider.get_response(
            request.method,
            signed_data,
            get_header,
        )
        response = HttpResponse(data, status=status_code)
        if status_code == 405 and isinstance(data, (list, tuple)):
            # HTTP requires a 405 response to list the allowed methods in Allow.
            response['Allow'] = ', '.join(data)
        return response

    return csrf_exempt(provider_view)