123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import hashlib
- from typing import Any
- from urllib.parse import urlparse
- import datetime
- from base64 import b64encode, b64decode
- from Crypto import Random
- from Crypto.PublicKey import RSA
- from Crypto.Signature import pkcs1_15 # pylint: disable=no-name-in-module
- from Crypto.Hash import SHA256
# Maximum age, in seconds, of a signed request's Date header before the
# request is rejected as too old (five minutes).
MAX_SIGNATURE_AGE = 5 * 60
def create_key_pair(key_size: int = 2048) -> "tuple[str, str]":
    """Create a new RSA key pair for a new user.

    :param key_size: RSA modulus length in bits. Defaults to 2048, since
        1024-bit RSA keys are below current recommended minimums.
    :returns: ``(private_key, public_key)``, both PEM-encoded strings.
    """
    # The return annotation is a string so it is never evaluated at
    # definition time (calling ``tuple(str, str)`` would raise TypeError).
    random_generator = Random.new().read
    key = RSA.generate(key_size, random_generator)
    private_key = key.export_key().decode("utf8")
    public_key = key.publickey().export_key().decode("utf8")
    return private_key, public_key
def make_signature(sender, destination, date, digest):
    """Build the ``Signature`` header value for an outgoing POST.

    Signs ``(request-target)``, ``host``, ``date`` and ``digest`` with the
    sender's RSA private key (PKCS#1 v1.5 over SHA-256).

    :param sender: local user; ``sender.key_pair.private_key`` (PEM) and
        ``sender.remote_id`` are read.
    :param destination: full URL of the remote inbox being posted to.
    :param date: value of the request's Date header.
    :param digest: value of the request's Digest header.
    :returns: comma-separated ``name="value"`` signature parameters.
    """
    inbox = urlparse(destination)
    signing_string = "\n".join(
        (
            f"(request-target): post {inbox.path}",
            f"host: {inbox.netloc}",
            f"date: {date}",
            f"digest: {digest}",
        )
    )
    private_key = RSA.import_key(sender.key_pair.private_key)
    hashed = SHA256.new(signing_string.encode("utf8"))
    raw_signature = pkcs1_15.new(private_key).sign(hashed)
    params = (
        ("keyId", f"{sender.remote_id}#main-key"),
        ("algorithm", "rsa-sha256"),
        ("headers", "(request-target) host date digest"),
        ("signature", b64encode(raw_signature).decode("utf8")),
    )
    return ",".join(f'{name}="{value}"' for name, value in params)
def make_digest(data):
    """Build an HTTP ``Digest`` header value (SHA-256) for a request body.

    :param data: the body, either ``str`` (encoded as UTF-8 before hashing)
        or bytes-like (hashed as-is, as ``verify_digest`` does).
    :returns: ``"SHA-256=<base64 of the raw SHA-256 digest>"``.
    """
    if isinstance(data, str):
        data = data.encode("utf-8")
    return "SHA-256=" + b64encode(hashlib.sha256(data).digest()).decode("utf-8")
def verify_digest(request):
    """Check that the request's ``Digest`` header matches its body.

    The header may carry one or more comma-separated ``algorithm=base64``
    entries (RFC 3230 instance digests). Algorithm names are matched
    case-insensitively. Every entry with a supported algorithm (SHA-256,
    SHA-512) must match the body. Entries with other algorithms are ignored,
    as long as at least one supported entry is present.

    :param request: object exposing ``headers`` (mapping with a ``"digest"``
        key) and ``body`` (bytes) -- presumably a Django request; confirm.
    :raises KeyError: if the request has no ``digest`` header.
    :raises ValueError: if an entry has no ``=``, no supported algorithm is
        present, or a supported digest does not match the body.
    """
    hash_functions = {
        "SHA-256": hashlib.sha256,
        "SHA-512": hashlib.sha512,
    }
    unsupported = []
    checked_any = False
    for entry in request.headers["digest"].split(","):
        # a missing "=" raises ValueError from the unpacking, as before
        algorithm, digest = entry.strip().split("=", 1)
        hash_function = hash_functions.get(algorithm.upper())
        if hash_function is None:
            unsupported.append(algorithm)
            continue
        if b64decode(digest) != hash_function(request.body).digest():
            raise ValueError("Invalid HTTP Digest header")
        checked_any = True
    if not checked_any:
        raise ValueError(f"Unsupported hash function: {', '.join(unsupported)}")
class Signature:
    """Read and validate incoming HTTP signatures."""

    def __init__(self, key_id, headers, signature):
        # value of the keyId parameter
        self.key_id = key_id
        # space-separated names of the signed headers, in signing order
        self.headers = headers
        # raw signature bytes (already base64-decoded by ``parse``)
        self.signature = signature

    # pylint: disable=invalid-name
    @classmethod
    def parse(cls, signature: str):
        """Parse the value of an HTTP ``Signature`` header.

        :param signature: comma-separated ``name="value"`` pairs, e.g.
            ``keyId="...",headers="...",signature="..."``; optional
            whitespace after the commas is tolerated.
        :returns: a ``Signature`` holding the keyId, the signed header list
            and the base64-decoded signature bytes.
        :raises ValueError: if a pair has no ``=``, if ``keyId``, ``headers``
            or ``signature`` is missing, or if the signature cannot be
            base64-decoded.
        """
        signature_dict = {}
        for pair in signature.split(","):
            k, v = pair.split("=", 1)
            signature_dict[k.strip()] = v.strip().replace('"', "")
        try:
            key_id = signature_dict["keyId"]
            headers = signature_dict["headers"]
            signature = b64decode(signature_dict["signature"])
        except KeyError as err:
            raise ValueError("Invalid auth header") from err
        return cls(key_id, headers, signature)

    def verify(self, public_key, date, request):
        """Verify the request's RSA-SHA256 signature with ``public_key``.

        Rebuilds the signing string from the header names in ``self.headers``
        and checks ``self.signature`` against it.

        :param public_key: PEM-encoded RSA public key of the claimed signer.
        :param date: not used; the Date header is read from ``request``.
            Kept so existing callers keep working.
        :param request: incoming request exposing ``headers``, ``path`` and
            ``body`` (presumably a Django ``HttpRequest`` -- confirm).
        :raises ValueError: if the Date header is more than
            ``MAX_SIGNATURE_AGE`` seconds in the past or the future, if the
            Digest header does not match the body, or if the signature is
            invalid.
        """
        request_date = request.headers["date"]
        age = http_date_age(request_date)
        if age > MAX_SIGNATURE_AGE:
            raise ValueError(f"Request too old: {request_date}")
        if age < -MAX_SIGNATURE_AGE:
            # Without a lower bound, a request dated far in the future would
            # stay acceptable (and replayable) until it finally aged out.
            raise ValueError(f"Request date is in the future: {request_date}")
        # NOTE(review): nothing requires "date" (or "digest") to be among the
        # signed headers; if they are unsigned, the date window and the body
        # digest are not protected by the signature. Consider enforcing it.

        public_key = RSA.import_key(public_key)

        comparison_lines = []
        # split() instead of split(" ") so repeated spaces do not produce an
        # empty header name
        for signed_header_name in self.headers.split():
            if signed_header_name == "(request-target)":
                # the method is assumed to be POST, as in make_signature
                comparison_lines.append(f"(request-target): post {request.path}")
                continue
            if signed_header_name == "digest":
                # raises ValueError if the body does not match its Digest header
                verify_digest(request)
            comparison_lines.append(
                f"{signed_header_name}: {request.headers[signed_header_name]}"
            )
        comparison_string = "\n".join(comparison_lines)

        signer = pkcs1_15.new(public_key)
        digest = SHA256.new(comparison_string.encode())
        # raises a ValueError if the signature does not match
        signer.verify(digest, self.signature)
def http_date_age(datestr):
    """Return how many seconds ago an HTTP date was.

    :param datestr: date in the exact form
        ``"Sun, 06 Nov 1994 08:49:37 GMT"``, interpreted as UTC.
    :returns: age in seconds as a float; negative for a future date.
    :raises ValueError: if ``datestr`` does not match that format.
    """
    # NOTE(review): strptime's %a/%b follow the current locale; this assumes
    # the process keeps an English (e.g. "C") LC_TIME.
    parsed = datetime.datetime.strptime(
        datestr, "%a, %d %b %Y %H:%M:%S GMT"
    ).replace(tzinfo=datetime.timezone.utc)
    # aware "now" instead of the deprecated naive datetime.utcnow()
    delta = datetime.datetime.now(datetime.timezone.utc) - parsed
    return delta.total_seconds()
|