#!/usr/bin/env python3
import argparse
import base64
import hashlib
import os.path
import shutil
import tempfile
import zipfile
from typing import NamedTuple


class Wheel(NamedTuple):
    """One platform-specific wheel to build from a prebuilt binary."""
    # file name of the binary inside the `--binaries` directory
    src: str
    # platform tag(s) of the resulting wheel; multiple tags are joined
    # with `.` and are written as separate `Tag:` lines in WHEEL
    plat: str
    # file name the binary gets inside the wheel's scripts directory
    exe: str = 'sentry-cli'


# One entry per prebuilt binary, given as (binary file name, platform tags).
# Binaries whose file name ends in `.exe` are installed as `sentry-cli.exe`;
# every other binary uses Wheel's default script name.
WHEELS = tuple(
    Wheel(src=binary, plat=tags, exe='sentry-cli.exe')
    if binary.endswith('.exe')
    else Wheel(src=binary, plat=tags)
    for binary, tags in (
        # macOS
        ('sentry-cli-Darwin-arm64', 'macosx_11_0_arm64'),
        ('sentry-cli-Darwin-universal', 'macosx_11_0_universal2'),
        ('sentry-cli-Darwin-x86_64', 'macosx_10_15_x86_64'),
        # Linux
        (
            'sentry-cli-Linux-aarch64',
            'manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_2_aarch64',
        ),
        (
            'sentry-cli-Linux-armv7',
            'manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_2_armv7l',
        ),
        (
            'sentry-cli-Linux-i686',
            'manylinux_2_17_i686.manylinux2014_i686.musllinux_1_2_i686',
        ),
        (
            'sentry-cli-Linux-x86_64',
            'manylinux_2_17_x86_64.manylinux2014_x86_64.musllinux_1_2_x86_64',
        ),
        # Windows
        ('sentry-cli-Windows-i686.exe', 'win32'),
        ('sentry-cli-Windows-x86_64.exe', 'win_amd64'),
    )
)


def _record_line(path: str, data: bytes) -> str:
    """Return the RECORD row for archive member *path* containing *data*.

    The row has the form ``path,sha256=<digest>,<size>`` followed by a
    newline, where the digest is the urlsafe-base64 encoded SHA-256 of
    *data* with the trailing ``=`` padding removed.
    """
    digest = hashlib.sha256(data).digest()
    digest_b64 = base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
    return f'{path},sha256={digest_b64},{len(data)}\n'


def main() -> int:
    """Build one platform wheel per entry in ``WHEELS``.

    Command line arguments (all required):

    - ``--binaries``: directory that must contain exactly the files named
      by ``Wheel.src`` for every entry in ``WHEELS``.
    - ``--base``: directory that must contain only one ``.tar.gz`` sdist
      and one ``.whl`` base wheel.
    - ``--dest``: output directory (created if needed); receives a copy
      of the sdist and every generated wheel.

    For each entry the base wheel is unpacked, its ``sentry-cli`` script
    is replaced by the prebuilt binary, the ``WHEEL`` tags and ``RECORD``
    rows are rewritten to match, and the result is zipped up under the
    new platform tag.

    Returns 0 on success.  Raises ``SystemExit`` with a message when the
    inputs are not as described above or when the expected ``RECORD`` /
    ``WHEEL`` lines cannot be found in the base wheel.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--binaries', required=True)
    parser.add_argument('--base', required=True)
    parser.add_argument('--dest', required=True)
    args = parser.parse_args()

    # the directory must match exactly: a missing binary is as much of an
    # error as an extra one, so compare for equality rather than subset
    expected = {wheel.src for wheel in WHEELS}
    received = set(os.listdir(args.binaries))
    if expected != received:
        raise SystemExit(
            f'Unexpected binaries:\n\n'
            f'- extra: {", ".join(sorted(received - expected))}\n'
            f'- missing: {", ".join(sorted(expected - received))}'
        )

    sdist_path = wheel_path = None
    for fname in os.listdir(args.base):
        if fname.endswith('.tar.gz'):
            sdist_path = os.path.join(args.base, fname)
        elif fname.endswith('.whl'):
            wheel_path = os.path.join(args.base, fname)
        else:
            raise SystemExit(f'unexpected file in `--base`: {fname}')

    if sdist_path is None or wheel_path is None:
        raise SystemExit('expected wheel and sdist in `--base`')

    # the base wheel's name does not depend on the platform being built,
    # so parse it once (and before anything is written to `--dest`);
    # the base wheel's own platform tag is not needed
    basename = os.path.basename(wheel_path)
    wheelname, _ = os.path.splitext(basename)
    name, version, py, abi, _ = wheelname.split('-')

    os.makedirs(args.dest, exist_ok=True)
    shutil.copy(sdist_path, args.dest)

    for wheel in WHEELS:
        binary_src = os.path.join(args.binaries, wheel.src)
        with open(binary_src, 'rb') as bf:
            binary_record = _record_line(
                f'{name}-{version}.data/scripts/{wheel.exe}',
                bf.read(),
            )

        with tempfile.TemporaryDirectory() as tmp:
            with zipfile.ZipFile(wheel_path) as zipf:
                zipf.extractall(tmp)

            distinfo = os.path.join(tmp, f'{name}-{version}.dist-info')
            scripts = os.path.join(tmp, f'{name}-{version}.data', 'scripts')

            # replace the script binary with our copy
            os.remove(os.path.join(scripts, 'sentry-cli'))
            shutil.copy(binary_src, os.path.join(scripts, wheel.exe))

            # rewrite WHEEL to have the new tags (done before RECORD so
            # that RECORD can describe the rewritten file)
            wheel_fname = os.path.join(distinfo, 'WHEEL')
            with open(wheel_fname, encoding='utf-8') as f:
                wheel_lines = list(f)

            for i, line in enumerate(wheel_lines):
                if line.startswith('Tag: '):
                    # one `Tag:` line per platform tag of this wheel
                    wheel_lines[i:i + 1] = [
                        f'Tag: {py}-{abi}-{tag}\n'
                        for tag in wheel.plat.split('.')
                    ]
                    break
            else:
                raise SystemExit("could not find 'Tag: ' in WHEEL")

            with open(wheel_fname, 'w', encoding='utf-8') as f:
                f.writelines(wheel_lines)

            # hash the bytes actually on disk so the row matches exactly
            # what ends up in the archive
            with open(wheel_fname, 'rb') as f:
                wheel_record = _record_line(
                    f'{name}-{version}.dist-info/WHEEL',
                    f.read(),
                )

            # rewrite RECORD to include the new file and the new WHEEL
            record_fname = os.path.join(distinfo, 'RECORD')
            with open(record_fname, encoding='utf-8') as f:
                record_lines = list(f)

            script_prefix = f'{name}-{version}.data/scripts/sentry-cli,'
            wheel_prefix = f'{name}-{version}.dist-info/WHEEL,'
            script_found = False
            for i, line in enumerate(record_lines):
                if not script_found and line.startswith(script_prefix):
                    record_lines[i] = binary_record
                    script_found = True
                elif line.startswith(wheel_prefix):
                    # WHEEL was modified above; refresh its hash and size
                    record_lines[i] = wheel_record

            if not script_found:
                raise SystemExit(
                    f'could not find {script_prefix!r} in RECORD',
                )

            with open(record_fname, 'w', encoding='utf-8') as f:
                f.writelines(record_lines)

            # write out the final zip
            new_basename = f'{name}-{version}-{py}-{abi}-{wheel.plat}.whl'
            tmp_new_wheel = os.path.join(tmp, new_basename)
            # collect the members before the archive is created inside
            # `tmp` so the archive does not include itself; sorted so the
            # member order does not depend on directory listing order
            members = sorted(
                os.path.join(root, member)
                for root, _, dir_fnames in os.walk(tmp)
                for member in dir_fnames
            )
            with zipfile.ZipFile(tmp_new_wheel, 'w') as zipf:
                for member in members:
                    zinfo = zipfile.ZipInfo(os.path.relpath(member, tmp))
                    if '/scripts/' in zinfo.filename:
                        # regular file with mode 0755 so that the
                        # binary is executable
                        zinfo.external_attr = 0o100755 << 16
                    with open(member, 'rb') as fb:
                        zipf.writestr(zinfo, fb.read())

            # move into dest
            shutil.move(tmp_new_wheel, args.dest)

    return 0


if __name__ == '__main__':
    # main()'s return value becomes the process exit status
    raise SystemExit(main())
