/*  Copyright (C) 2023-2024 Andreas Shimokawa, José Rebelo, Yoran Vulker

    This file is part of Gadgetbridge.

    Gadgetbridge is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    Gadgetbridge is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>. */
package nodomain.freeyourgadget.gadgetbridge.service.devices.xiaomi;

import android.content.Context;
import android.content.SharedPreferences;
import android.os.Build;
import android.widget.Toast;

import androidx.annotation.Nullable;

import com.google.protobuf.ByteString;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bouncycastle.shaded.crypto.CryptoException;
import org.bouncycastle.shaded.crypto.engines.AESEngine;
import org.bouncycastle.shaded.crypto.modes.CCMBlockCipher;
import org.bouncycastle.shaded.crypto.params.AEADParameters;
import org.bouncycastle.shaded.crypto.params.KeyParameter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Locale;

import javax.crypto.Cipher;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import nodomain.freeyourgadget.gadgetbridge.GBApplication;
import nodomain.freeyourgadget.gadgetbridge.R;
import nodomain.freeyourgadget.gadgetbridge.impl.GBDevice;
import nodomain.freeyourgadget.gadgetbridge.proto.xiaomi.XiaomiProto;
import nodomain.freeyourgadget.gadgetbridge.service.devices.xiaomi.services.AbstractXiaomiService;
import nodomain.freeyourgadget.gadgetbridge.util.GB;

public class XiaomiAuthService extends AbstractXiaomiService {
    private static final Logger LOG = LoggerFactory.getLogger(XiaomiAuthService.class);


    public static final int COMMAND_TYPE = 1;

    public static final int CMD_SEND_USERID = 5;
    public static final int CMD_NONCE = 26;
    public static final int CMD_AUTH = 27;

    private boolean encryptionInitialized = false;
    private boolean checkDecryptionMac = true;

    private final byte[] secretKey = new byte[16];
    private final byte[] nonce = new byte[16];
    private final byte[] encryptionKey = new byte[16];
    private final byte[] decryptionKey = new byte[16];
    private final byte[] encryptionNonce = new byte[4];
    private final byte[] decryptionNonce = new byte[4];

    public XiaomiAuthService(final XiaomiSupport support) {
        super(support);
    }

    public boolean isEncryptionInitialized() {
        return encryptionInitialized;
    }

    protected void startEncryptedHandshake() {
        encryptionInitialized = false;

        System.arraycopy(getSecretKey(getSupport().getDevice()), 0, secretKey, 0, 16);
        new SecureRandom().nextBytes(nonce);

        getSupport().sendCommand("auth step 1", buildNonceCommand(nonce));
    }

    protected void startClearTextHandshake() {
        final XiaomiProto.Auth auth = XiaomiProto.Auth.newBuilder()
                .setUserId(getUserId(getSupport().getDevice()))
                .build();

        final XiaomiProto.Command command = XiaomiProto.Command.newBuilder()
                .setType(XiaomiAuthService.COMMAND_TYPE)
                .setSubtype(XiaomiAuthService.CMD_SEND_USERID)
                .setAuth(auth)
                .build();

        getSupport().sendCommand("auth step 1", command);
    }

    @Override
    public void setContext(final Context context) {
        super.setContext(context);
        this.checkDecryptionMac = getCoordinator().checkDecryptionMac();
    }

    @Override
    public void handleCommand(final XiaomiProto.Command cmd) {
        if (cmd.getType() != COMMAND_TYPE) {
            throw new IllegalArgumentException("Not an auth command");
        }

        switch (cmd.getSubtype()) {
            case CMD_NONCE: {
                LOG.debug("Got watch nonce");

                // Watch nonce
                final XiaomiProto.Command command = handleWatchNonce(cmd.getAuth().getWatchNonce());

                if (command == null) {
                    GB.toast(getSupport().getContext(), R.string.authentication_failed_check_key, Toast.LENGTH_LONG, GB.WARN);
                    LOG.error("handleWatchNonce returned null, disconnecting");
                    final GBDevice device = getSupport().getDevice();

                    if (device != null) {
                        GBApplication.deviceService(device).disconnect();
                    }

                    return;
                }

                getSupport().sendCommand("auth step 2", command);
                break;
            }

            case CMD_AUTH:
            case CMD_SEND_USERID: {
                if (cmd.getSubtype() == CMD_AUTH || cmd.getAuth().getStatus() == 1) {
                    encryptionInitialized = cmd.getSubtype() == CMD_AUTH;

                    LOG.info("Authenticated, further communications are {}", encryptionInitialized ? "encrypted" : "in plaintext");

                    getSupport().getDevice().setUpdateState(GBDevice.State.INITIALIZED, getSupport().getContext());

                    getSupport().onAuthSuccess();
                } else {
                    LOG.warn("Authentication failed, subtype={}, status={}", cmd.getSubtype(), cmd.getStatus());
                    GB.toast(getSupport().getContext(), R.string.authentication_failed_check_key, Toast.LENGTH_LONG, GB.WARN);

                    final GBDevice device = getSupport().getDevice();
                    if (device != null) {
                        GBApplication.deviceService(device).disconnect();
                    }
                }
                break;
            }
            default:
                LOG.warn("Unknown auth payload subtype {}", cmd.getSubtype());
        }
    }

    public byte[] encrypt(final byte[] arr, final int i) {
        final ByteBuffer packetNonce = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
                .put(encryptionNonce)
                .putInt(0)
                .putInt(i);

        try {
            return encrypt(encryptionKey, packetNonce.array(), arr);
        } catch (final CryptoException e) {
            throw new RuntimeException("failed to encrypt", e);
        }
    }

    public byte[] decrypt(final byte[] arr) {
        final ByteBuffer packetNonce = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN);
        packetNonce.put(decryptionNonce);
        packetNonce.putInt(0);
        packetNonce.putInt(0);

        try {
            return decrypt(decryptionKey, packetNonce.array(), arr, checkDecryptionMac);
        } catch (final CryptoException e) {
            throw new RuntimeException("failed to decrypt", e);
        }
    }

    @Nullable
    private XiaomiProto.Command handleWatchNonce(final XiaomiProto.WatchNonce watchNonce) {
        final byte[] step2hmac = computeAuthStep3Hmac(secretKey, nonce, watchNonce.getNonce().toByteArray());

        System.arraycopy(step2hmac, 0, decryptionKey, 0, 16);
        System.arraycopy(step2hmac, 16, encryptionKey, 0, 16);
        System.arraycopy(step2hmac, 32, decryptionNonce, 0, 4);
        System.arraycopy(step2hmac, 36, encryptionNonce, 0, 4);

        LOG.debug("decryptionKey: {}", GB.hexdump(decryptionKey));
        LOG.debug("encryptionKey: {}", GB.hexdump(encryptionKey));
        LOG.debug("decryptionNonce: {}", GB.hexdump(decryptionNonce));
        LOG.debug("encryptionNonce: {}", GB.hexdump(encryptionNonce));

        final byte[] decryptionConfirmation = hmacSHA256(decryptionKey, ArrayUtils.addAll(watchNonce.getNonce().toByteArray(), nonce));
        if (!Arrays.equals(decryptionConfirmation, watchNonce.getHmac().toByteArray())) {
            LOG.warn("Watch hmac mismatch");
            return null;
        }

        final XiaomiProto.AuthDeviceInfo authDeviceInfo = XiaomiProto.AuthDeviceInfo.newBuilder()
                .setUnknown1(0) // TODO ?
                .setPhoneApiLevel(Build.VERSION.SDK_INT)
                .setPhoneName(Build.MODEL)
                .setUnknown3(224) // TODO ?
                // TODO region should be actual device region?
                .setRegion(Locale.getDefault().getLanguage().substring(0, 2).toUpperCase(Locale.ROOT))
                .build();

        final byte[] encryptedNonces = hmacSHA256(encryptionKey, ArrayUtils.addAll(nonce, watchNonce.getNonce().toByteArray()));
        final byte[] encryptedDeviceInfo = encrypt(authDeviceInfo.toByteArray(), 0);
        final XiaomiProto.AuthStep3 authStep3 = XiaomiProto.AuthStep3.newBuilder()
                .setEncryptedNonces(ByteString.copyFrom(encryptedNonces))
                .setEncryptedDeviceInfo(ByteString.copyFrom(encryptedDeviceInfo))
                .build();

        final XiaomiProto.Command.Builder cmd = XiaomiProto.Command.newBuilder();
        cmd.setType(COMMAND_TYPE);
        cmd.setSubtype(CMD_AUTH);

        final XiaomiProto.Auth.Builder auth = XiaomiProto.Auth.newBuilder();
        auth.setAuthStep3(authStep3);

        return cmd.setAuth(auth.build()).build();
    }

    public static XiaomiProto.Command buildNonceCommand(final byte[] nonce) {
        final XiaomiProto.PhoneNonce.Builder phoneNonce = XiaomiProto.PhoneNonce.newBuilder();
        phoneNonce.setNonce(ByteString.copyFrom(nonce));

        final XiaomiProto.Auth.Builder auth = XiaomiProto.Auth.newBuilder();
        auth.setPhoneNonce(phoneNonce.build());

        final XiaomiProto.Command.Builder command = XiaomiProto.Command.newBuilder();
        command.setType(COMMAND_TYPE);
        command.setSubtype(CMD_NONCE);
        command.setAuth(auth.build());
        return command.build();
    }

    public static byte[] computeAuthStep3Hmac(final byte[] secretKey,
                                              final byte[] phoneNonce,
                                              final byte[] watchNonce) {
        final byte[] miwearAuthBytes = "miwear-auth".getBytes();

        final Mac mac;
        try {
            mac = Mac.getInstance("HmacSHA256");
            // Compute the actual key and re-initialize the mac
            mac.init(new SecretKeySpec(ArrayUtils.addAll(phoneNonce, watchNonce), "HmacSHA256"));
            final byte[] hmacKeyBytes = mac.doFinal(secretKey);
            final SecretKeySpec key = new SecretKeySpec(hmacKeyBytes, "HmacSHA256");
            mac.init(key);
        } catch (final NoSuchAlgorithmException | InvalidKeyException e) {
            throw new IllegalStateException("Failed to initialize hmac for auth step 2", e);
        }

        final byte[] output = new byte[64];
        byte[] tmp = new byte[0];
        byte b = 1;
        int i = 0;
        while (i < output.length) {
            mac.update(tmp);
            mac.update(miwearAuthBytes);
            mac.update(b);
            tmp = mac.doFinal();
            for (int j = 0; j < tmp.length && i < output.length; j++, i++) {
                output[i] = tmp[j];
            }
            b++;
        }
        return output;
    }

    protected static byte[] getSecretKey(final GBDevice device) {
        final byte[] authKeyBytes = new byte[16];

        final SharedPreferences sharedPrefs = GBApplication.getDeviceSpecificSharedPrefs(device.getAddress());

        final String authKey = sharedPrefs.getString("authkey", "").trim();
        if (StringUtils.isNotBlank(authKey)) {
            final byte[] srcBytes;
            // Allow both with and without 0x, to avoid user mistakes
            if (authKey.length() == 34 && authKey.startsWith("0x")) {
                srcBytes = GB.hexStringToByteArray(authKey.trim().substring(2));
            } else {
                srcBytes = GB.hexStringToByteArray(authKey.trim());
            }
            System.arraycopy(srcBytes, 0, authKeyBytes, 0, Math.min(srcBytes.length, 16));
        }

        return authKeyBytes;
    }

    protected static String getUserId(final GBDevice device) {
        final SharedPreferences sharedPrefs = GBApplication.getDeviceSpecificSharedPrefs(device.getAddress());

        final String authKey = sharedPrefs.getString("authkey", null);
        if (StringUtils.isNotBlank(authKey)) {
            return authKey;
        }

        return "0000000000";
    }

    protected static byte[] hmacSHA256(final byte[] key, final byte[] input) {
        try {
            final Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(key, "HmacSHA256"));
            return mac.doFinal(input);
        } catch (final Exception e) {
            throw new RuntimeException("Failed to hmac", e);
        }
    }

    public static byte[] encrypt(final byte[] key, final byte[] nonce, final byte[] payload) throws
            CryptoException {
        final CCMBlockCipher cipher = createBlockCipher(true, new SecretKeySpec(key, "AES"), 32, nonce);
        final byte[] out = new byte[cipher.getOutputSize(payload.length)];
        final int outBytes = cipher.processBytes(payload, 0, payload.length, out, 0);
        cipher.doFinal(out, outBytes);
        return out;
    }

    public static byte[] decrypt(final byte[] key,
                                 final byte[] nonce,
                                 final byte[] encryptedPayload,
                                 final boolean checkMac) throws CryptoException {
        final int macSizeBits = checkMac ? 32 : 0;
        final int actualEncryptedLength = checkMac ? encryptedPayload.length : encryptedPayload.length - 4;
        final CCMBlockCipher cipher = createBlockCipher(false, new SecretKeySpec(key, "AES"), macSizeBits, nonce);
        final byte[] decrypted = new byte[cipher.getOutputSize(actualEncryptedLength)];
        cipher.doFinal(decrypted, cipher.processBytes(encryptedPayload, 0, actualEncryptedLength, decrypted, 0));
        return decrypted;
    }

    public static CCMBlockCipher createBlockCipher(final boolean forEncrypt,
                                                   final SecretKey secretKey,
                                                   final int macSizeBits,
                                                   final byte[] nonce) {
        final AESEngine aesFastEngine = new AESEngine();
        aesFastEngine.init(forEncrypt, new KeyParameter(secretKey.getEncoded()));
        final CCMBlockCipher blockCipher = new CCMBlockCipher(aesFastEngine);
        blockCipher.init(forEncrypt, new AEADParameters(new KeyParameter(secretKey.getEncoded()), macSizeBits, nonce, null));
        return blockCipher;
    }

    public byte[] encryptV2(final byte[] message) {
        try {
            // I wish I was kidding
            return ctrCrypt(Cipher.ENCRYPT_MODE, encryptionKey, encryptionKey, message);
        } catch (final GeneralSecurityException ex) {
            throw new RuntimeException("failed to encrypt message", ex);
        }
    }

    public byte[] decryptV2(final byte[] ciphertext) {
        try {
            // I wish I was kidding
            return ctrCrypt(Cipher.DECRYPT_MODE, decryptionKey, decryptionKey, ciphertext);
        } catch (final GeneralSecurityException ex) {
            throw new RuntimeException("failed to decrypt message", ex);
        }
    }

    public byte[] ctrCrypt(final int op, final byte[] key, final byte[] iv, final byte[] message) throws GeneralSecurityException {
        final Cipher cipher = Cipher.getInstance("AES/CTR/NoPadding");
        cipher.init(
                op,
                new SecretKeySpec(key, "AES"),
                new IvParameterSpec(iv)
        );
        return cipher.doFinal(message);
    }
}
