#!/usr/bin/python3
#
# verify_combine_sigs
#
# Helper script for shim-signed
#
# Microsoft currently only returns signed binaries with one signature;
# if they are signing with more than one key/cert, then we will get
# multiple separate signed binaries, one per key/cert.
#
# Check that all our signed shims are signed with an expected key that
# we can remove and re-add; error out otherwise.
#
# Then finally we will add all those signatures to one output binary.
#
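# The order of the signatures in the final binary is NOT taken from the
# order of the signed shims on the command line: signatures are
# attached in sorted order of the filenames of the certificates they
# matched (see SIGN_CERTS). Name the certificates so that they sort in
# the order we'd like the signatures in the final binary. It's
# recommended to do this in the order:
#
# <oldest CA>
# ...
# <newest CA>
#
# as that is most likely to work with older firmware implementations.
# That's most easily achieved by naming the certificates like
# 0001-first-CA.crt, 0002-second-CA.crt, etc.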

import argparse
import glob
import hashlib
import os
import re
import shutil
import subprocess
import sys

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import pkcs7

# Allowed certificates - each binary must be signed with a certificate
# from this set. This is a glob pattern resolved relative to the current
# working directory; the matching files are read as PEM certificates and
# processed in sorted filename order. That sorted order also decides the
# order of the signatures in the combined output binary.
SIGN_CERTS = "0*.crt"

# Path to sbverify - special build for now
SBVERIFY = "/usr/bin/sbverify"


def parse_args():
    """
    Parse the command line.

    Returns an argparse namespace with:
    - efi_arch: the architecture suffix used to locate
      /usr/lib/shim/shim<efi_arch>.efi (required, --efi_arch / -a)
    - signed: a list of one or more signed binaries to verify and merge
    """
    cli = argparse.ArgumentParser(description="verify_combine_sigs")
    cli.add_argument(
        "--efi_arch",
        "-a",
        required=True,
        help="EFI architecture for binaries",
    )
    cli.add_argument(
        "signed",
        nargs="+",
        help="signed binaries to verify/merge",
    )
    return cli.parse_args()


def grab_cert_details(check_cert: str) -> tuple[str, str, str]:
    """
    Parse a PEM certificate from disk and grab out:
    - the Subject text, built as "/" followed by the subject attribute
      values joined with "/" (values only, no "CN="-style attribute
      names). This is the same form certs_in_detached_signature() builds,
      so the two can be compared directly.
    - the SHA-1 fingerprint of the certificate, as lowercase hex
    - the SHA-256 fingerprint of the certificate, as lowercase hex

    Each value is printed as it is found.

    Returns a (subject, sha1, sha256) tuple.
    Raises ValueError (from cryptography) if the file is not a parseable
    PEM certificate.
    """
    print(f"Loading details from {check_cert}")
    with open(check_cert, "rb") as inf:
        pem_data = inf.read()
    cert = x509.load_pem_x509_certificate(pem_data)
    subject = "/" + "/".join(x.value for x in cert.subject)
    sha1 = cert.fingerprint(hashes.SHA1()).hex()
    sha256 = cert.fingerprint(hashes.SHA256()).hex()
    print(f" - {subject}")
    print(f" - sha1sum {sha1}")
    print(f" - sha256sum {sha256}")
    return subject, sha1, sha256


def list_signatures(signed_filename: str):
    """
    Print the `sbverify -l` listing for signed_filename, one line per line.

    sbverify's stderr is discarded. Raises subprocess.CalledProcessError
    if sbverify exits with a non-zero status.
    """
    listing = subprocess.check_output(
        [SBVERIFY, "-l", signed_filename],
        stderr=subprocess.DEVNULL,
        text=True,
    ).splitlines()
    if listing:
        print("\n".join(listing))


def verify_signature(signed_filename: str, certs: list[str]):
    """
    Run sbverify on signed_filename, passing every path in certs as its
    own --cert option.

    Returns nothing on success. Raises subprocess.CalledProcessError if
    sbverify exits non-zero; the exception's output attribute then holds
    sbverify's combined stdout/stderr (as bytes).
    """
    cert_options = [token for cert in certs for token in ("--cert", cert)]
    subprocess.check_output(
        [SBVERIFY, signed_filename, *cert_options],
        stderr=subprocess.STDOUT,
    )


def parse_sbverify(signed_filename: str) -> list[str]:
    """
    Run `sbverify -l` on a signed binary and return one entry per
    signature found.

    Each entry is the first issuer line that follows that signature's
    "image signature issuers:" heading, with its leading " - " removed.
    The parser expects output shaped like:

        signature N
        image signature issuers:
         - <issuer>

    Raises subprocess.CalledProcessError if sbverify exits non-zero.
    """
    cmd = [SBVERIFY, "-l", signed_filename]
    output = subprocess.check_output(cmd, text=True, stderr=subprocess.STDOUT)
    issuers = []
    # 0: looking for a "signature" header
    # 1: inside a signature, looking for the issuers heading
    # 2: the next line is the (first) issuer
    state = 0
    for line in output.splitlines():
        if line.startswith("signature"):
            state = 1
        elif state == 1 and line.startswith("image signature issuers:"):
            state = 2
        elif state == 2:
            issuers.append(line.removeprefix(" - "))
            state = 0
    return issuers


def detach_signature(signed_filename: str, signum: int, outfile: str):
    """
    Copy numbered signature `signum` out of signed_filename's signature
    table into outfile, using `sbattach --signum N --detach`.

    Raises subprocess.CalledProcessError if sbattach exits non-zero.
    """
    argv = ["sbattach", "--signum", str(signum), "--detach", outfile, signed_filename]
    subprocess.check_output(argv, stderr=subprocess.STDOUT, text=True)


def certs_in_detached_signature(detached: str) -> list[dict]:
    """
    Describe every certificate embedded in a detached DER PKCS#7 signature.

    Each certificate becomes a dict with keys:
    - "sha1": SHA-1 fingerprint, lowercase hex
    - "sha256": SHA-256 fingerprint, lowercase hex
    - "subject": "/" followed by the subject attribute values joined by "/"

    The list is returned in the reverse of the order the certificates
    appear in the blob. This assumes the blob lists them leaf-first, so the
    result runs CA -> leaf. NOTE(review): PKCS#7 stores certificates as a
    set, so their order is not guaranteed in general; confirm it against
    the signatures actually received.
    """
    with open(detached, "rb") as blob_file:
        blob = blob_file.read()

    described = [
        {
            "sha1": cert.fingerprint(hashes.SHA1()).hex(),
            "sha256": cert.fingerprint(hashes.SHA256()).hex(),
            "subject": "/" + "/".join(attr.value for attr in cert.subject),
        }
        for cert in pkcs7.load_der_pkcs7_certificates(blob)
    ]
    return described[::-1]


def attach_sig(sigfile: str, unsigned: str):
    """
    Attach the detached signature in sigfile onto the binary `unsigned`
    with `sbattach --attach`. The result is written back to `unsigned`.

    Raises subprocess.CalledProcessError if sbattach exits non-zero.
    """
    argv = ["sbattach", "--attach", sigfile, unsigned]
    subprocess.check_output(argv, stderr=subprocess.STDOUT, text=True)


def checksum_file(filename: str) -> str:
    """
    Calculate the sha256sum of a file.

    Returns the SHA-256 digest as a lowercase hex string. The file is read
    in fixed-size chunks, so memory use does not grow with the file size.
    """
    hasher = hashlib.sha256()
    with open(filename, "rb") as inf:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for chunk in iter(lambda: inf.read(65536), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def _load_known_certs() -> list[dict]:
    """
    Load details of every allowed certificate matching SIGN_CERTS.

    Returns a list of dicts with "subject", "sha1", "sha256" and
    "filename" keys, in sorted filename order.
    """
    print("Loading details of all the expected certificates")
    print("==========")
    known_certs = []
    for check_cert in sorted(glob.glob(SIGN_CERTS)):
        subject, sha1, sha256 = grab_cert_details(check_cert)
        # A list (rather than a dict keyed on subject) means two
        # certificate files sharing a subject cannot silently shadow
        # one another.
        known_certs.append(
            {
                "subject": subject,
                "sha1": sha1,
                "sha256": sha256,
                "filename": check_cert,
            }
        )
    print("")
    return known_certs


def _check_signed_binary(
    signed: str, known_certs: list[dict], build: str, source: str, unsigned: str
) -> None:
    """
    Verify one signed shim and stash its signature for the combined binary.

    Checks that `signed` verifies against the known certificates and has
    exactly one signature. It also checks that the signature's first
    certificate (after reversal, the CA end of the chain) is one of
    known_certs. Finally, it checks that re-attaching the signature to a
    fresh copy of `source` reproduces `signed` byte for byte.

    On success, leaves behind build/detached-<certfile> (the detached
    signature) and build/sha1-<certfile> (the colon-separated SHA-1 of the
    matching certificate). <certfile> is the filename of the matching
    certificate. Exits the process with status 1 on any failure.
    """
    print(f"Checking {signed}")
    print("----------\n")

    # Verify that the image is signed and valid
    print("Looking for any valid checksum and signature")
    try:
        verify_signature(signed, [cert["filename"] for cert in known_certs])
    except Exception as exc:
        print(f"Invalid signature on {signed}: {exc}")
        # str(CalledProcessError) omits sbverify's own output, which is
        # usually what explains the failure, so show it too.
        details = getattr(exc, "output", None)
        if isinstance(details, bytes):
            details = details.decode(errors="replace")
        if details:
            print(details.rstrip())
        sys.exit(1)

    signatures = parse_sbverify(signed)
    num = len(signatures)
    if num != 1:
        print(f"Only expected 1 signature, but {signed} has {num}!")
        print("Abort")
        sys.exit(1)
    print(f"{signed} has 1 signature, good!")

    # Now see what signature we have. We'll have to extract the
    # signature table here, then extract the list of certificates
    # included in the 1 signature we have.
    detached_sig = "detached.sig"
    detach_signature(signed, 1, detached_sig)
    sig_certs = certs_in_detached_signature(detached_sig)

    print("certs attached:")
    for cert in sig_certs:
        print(f'  - {cert["subject"]}')
        print(f'  - sha1 {cert["sha1"]}')
        print(f'  - sha256 {cert["sha256"]}')

    # Without this guard, sig_certs[0] below would raise IndexError.
    if not sig_certs:
        print(f"\nERROR: {signed} signature contains no certificates, abort!")
        sys.exit(1)

    # Now we need to compare the root certificate there to our
    # known certificates (subject AND fingerprint must match).
    root = sig_certs[0]
    matched = next(
        (
            known
            for known in known_certs
            if root["subject"] == known["subject"]
            and root["sha256"] == known["sha256"]
        ),
        None,
    )
    if matched is None:
        print(f"\nERROR: {signed} signature unknown, abort!")
        sys.exit(1)
    matched_filename = matched["filename"]
    print(f"\nroot certificate matches a known certificate ({matched_filename})")

    # Move the detached signature to one side, for future use. Refuse to
    # overwrite one stashed by an earlier binary, which would silently
    # drop a signature from the combined output.
    new_filename = os.path.join(build, f"detached-{matched_filename}")
    if os.path.exists(new_filename):
        print(
            f"\nERROR: {signed} is signed with the same certificate "
            f"({matched_filename}) as an earlier binary, abort!"
        )
        sys.exit(1)
    shutil.move(detached_sig, new_filename)

    # And write out the sha1 checksum of the cert (as aa:bb:cc:...) for later use
    sha1_filename = os.path.join(build, f"sha1-{matched_filename}")
    with open(sha1_filename, "w") as outf:
        outf.write(":".join(re.findall("..", matched["sha1"])))

    # Copy our matching unsigned binary into the build directory.
    shutil.copy(source, unsigned)

    # Attach the signature to our unsigned binary, so we know that
    # the binary has not been tampered with during the signing
    # process.
    print("Checking the signature applies to our original binary")
    attach_sig(new_filename, unsigned)
    print("  Signature applies OK")

    # Now compare the result to the signed binary we were given
    print("Comparing the signed binaries")
    old_sha = checksum_file(signed)
    print(f"{old_sha}  {signed}")
    new_sha = checksum_file(unsigned)
    print(f"{new_sha}  {unsigned}")

    if old_sha != new_sha:
        print("\nERROR: signatures don't match, abort!")
        sys.exit(1)

    print("Binaries match!\n")


def _build_combined_binary(
    build: str, source: str, unsigned: str, efi_arch: str
) -> None:
    """
    Attach every stashed signature to a fresh copy of the unsigned shim,
    write their certificate fingerprints to "<unsigned>-signatures", and
    print the resulting signature list.

    Signatures are attached in sorted order of their build/detached-*
    filenames, i.e. sorted by the matching certificate filename.
    """
    print(f"Building final combined shim for arch {efi_arch} ...")
    print("==========")

    shutil.copy(source, unsigned)

    for sig in sorted(glob.glob(f"{build}/detached-*")):
        print(f"Adding signature {sig}")
        attach_sig(sig, unsigned)

    # Stick the signature fingerprints together, one per line. Both globs
    # share the same <certfile> suffixes, so this is the same order in
    # which the signatures were attached.
    with open(f"{unsigned}-signatures", "w") as outf:
        for fp in sorted(glob.glob(f"{build}/sha1-*")):
            with open(fp) as inf:
                outf.write(inf.read() + "\n")

    # And finally show the list of signatures
    print(f"Signatures on {unsigned} :")
    list_signatures(unsigned)


def main():
    """
    Verify each signed shim named on the command line, then combine all
    of their signatures onto one copy of the unsigned shim.

    Exits with status 1 on the first problem found.
    """
    args = parse_args()

    known_certs = _load_known_certs()
    if not known_certs:
        print(f"ERROR: no certificates matching {SIGN_CERTS} found, abort!")
        sys.exit(1)

    print(f"Verifying signatures for arch {args.efi_arch} ...")
    print("==========\n")

    build = "build"
    shutil.rmtree(build, ignore_errors=True)
    os.mkdir(build)

    source = f"/usr/lib/shim/shim{args.efi_arch}.efi"
    # Scratch copy of the unsigned shim. It is re-created for every
    # check and finally becomes the combined output binary.
    unsigned = f"{build}/shim{args.efi_arch}.efi.signed"

    for signed in args.signed:
        _check_signed_binary(signed, known_certs, build, source, unsigned)

    # If we've got this far, then we've checked all the binaries we
    # were given and things look OK. Now we want to build a single
    # output image with all the signatures attached.
    _build_combined_binary(build, source, unsigned, args.efi_arch)


# Run the verify-and-combine flow only when executed as a script.
if __name__ == "__main__":
    main()
