package fr.jnda.ipcalc.utils

import java.math.BigInteger
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress

/**
 * Computes subnet details (netmask, network and broadcast addresses, host range)
 * for an IPv4 address combined with a CIDR prefix length.
 *
 * @param ipAddress IPv4 address, resolved through [InetAddress.getByName]. Construction
 *   fails with a [ClassCastException] when the value does not resolve to an IPv4 address.
 * @param cidrPrefix prefix length, expected to be in `0..32`. The value is not validated;
 *   results for values outside that range are unspecified.
 */
class IPv4Helper(private val ipAddress: String, private val cidrPrefix: Int) {

    private val inetAddress: Inet4Address = InetAddress.getByName(ipAddress) as Inet4Address

    /** The four address octets packed big-endian into one 32-bit value. */
    private val addressBits: Int =
        inetAddress.address.fold(0) { acc, octet -> (acc shl 8) or (octet.toInt() and 0xFF) }

    /**
     * The netmask as a 32-bit value.
     *
     * `/0` is handled explicitly: Int shift distances use only their low five bits, so
     * `-1 shl 32` evaluates to `-1` (all ones) rather than the required `0`.
     */
    private val maskBits: Int = if (cidrPrefix == 0) 0 else -1 shl (32 - cidrPrefix)

    /** Returns the netmask in dotted-decimal notation, e.g. `255.255.255.0` for `/24`. */
    fun getNetmask(): String = toDottedDecimal(maskBits)

    /** Returns the network address: the address with all host bits cleared. */
    fun getNetworkAddress(): String = toDottedDecimal(networkBits())

    /** Returns the broadcast address: the address with all host bits set. */
    fun getBroadcastAddress(): String = toDottedDecimal(broadcastBits())

    /**
     * Returns the first host address of the subnet.
     *
     * For prefixes of 31 or more this is the network address itself; otherwise it is the
     * network address plus one.
     */
    fun getFirstHost(): String =
        if (cidrPrefix >= 31) getNetworkAddress() else toDottedDecimal(networkBits() + 1)

    /**
     * Returns the last host address of the subnet.
     *
     * For prefixes of 31 or more this is the broadcast address itself; otherwise it is the
     * broadcast address minus one.
     */
    fun getLastHost(): String =
        if (cidrPrefix >= 31) getBroadcastAddress() else toDottedDecimal(broadcastBits() - 1)

    /**
     * Returns the number of host addresses left after excluding the network and broadcast
     * addresses. Prefixes of 31 or more yield `0`.
     */
    fun getNumberOfHosts(): Long =
        if (cidrPrefix >= 31) 0L else (1L shl (32 - cidrPrefix)) - 2

    /** Address bits with the host part cleared. */
    private fun networkBits(): Int = addressBits and maskBits

    /** Address bits with the host part set to all ones. */
    private fun broadcastBits(): Int = addressBits or maskBits.inv()

    /** Formats a big-endian 32-bit value as a dotted-decimal IPv4 string. */
    private fun toDottedDecimal(bits: Int): String {
        val octets = byteArrayOf(
            (bits ushr 24).toByte(),
            (bits ushr 16).toByte(),
            (bits ushr 8).toByte(),
            bits.toByte()
        )
        return checkNotNull(InetAddress.getByAddress(octets).hostAddress) {
            "InetAddress returned no textual address"
        }
    }
}

/**
 * Computes subnet details (netmask, network address, address range) for an IPv6 address
 * combined with a prefix length.
 *
 * @param ipAddress IPv6 address, resolved through [InetAddress.getByName]. Construction
 *   fails with a [ClassCastException] when the value does not resolve to an IPv6 address.
 * @param cidrPrefix prefix length, expected to be in `0..128`. The value is not validated;
 *   results for values outside that range are unspecified.
 */
class IPv6Helper(private val ipAddress: String, private val cidrPrefix: Int) {

    private val inetAddress: Inet6Address = InetAddress.getByName(ipAddress) as Inet6Address

    /** Returns the netmask as an IPv6 address string (prefix bits set, host bits clear). */
    fun getNetmask(): String = format(prefixMask())

    /** Returns the network address: the address with all host bits cleared. */
    fun getNetworkAddress(): String = format(addressValue().and(prefixMask()))

    /** Returns the first address of the subnet, which is the network address. */
    fun getFirstHost(): String = getNetworkAddress()

    /** Returns the last address of the subnet: the address with all host bits set. */
    fun getLastHost(): String = format(addressValue().or(prefixMask().not()))

    /** Returns the total number of addresses in the subnet, `2^(128 - cidrPrefix)`. */
    fun getNumberOfHosts(): BigInteger = BigInteger.valueOf(2).pow(128 - cidrPrefix)

    /** The address as a non-negative 128-bit integer. */
    private fun addressValue(): BigInteger = BigInteger(1, inetAddress.address)

    /**
     * The prefix mask as a negative [BigInteger]: `-1` shifted left by the host-bit count,
     * i.e. an infinite run of one bits followed by the cleared host bits.
     */
    private fun prefixMask(): BigInteger = BigInteger.valueOf(-1).shiftLeft(128 - cidrPrefix)

    /**
     * Converts [value] to exactly 16 big-endian bytes.
     *
     * [BigInteger.toByteArray] returns the minimal two's-complement form, so longer results
     * keep only their low 16 bytes and shorter ones are sign-extended. Padding a negative
     * value with zero bytes would corrupt the high bits of the netmask.
     */
    private fun toAddressBytes(value: BigInteger): ByteArray {
        val raw = value.toByteArray()
        if (raw.size >= 16) {
            return raw.copyOfRange(raw.size - 16, raw.size)
        }
        val signByte = if (value.signum() < 0) (-1).toByte() else 0.toByte()
        return ByteArray(16 - raw.size) { signByte } + raw
    }

    /** Formats [value] as an IPv6 address string. */
    private fun format(value: BigInteger): String {
        val bytes = toAddressBytes(value)
        val address = InetAddress.getByAddress(bytes)
        if (address is Inet6Address) {
            return checkNotNull(address.hostAddress) { "InetAddress returned no textual address" }
        }
        // InetAddress.getByAddress turns IPv4-mapped byte patterns (::ffff:a.b.c.d) into an
        // Inet4Address; render those 16 bytes as eight hexadecimal groups instead of failing.
        return (0 until 16 step 2).joinToString(":") { index ->
            val group = ((bytes[index].toInt() and 0xFF) shl 8) or (bytes[index + 1].toInt() and 0xFF)
            group.toString(16)
        }
    }
}