package de.gultsch.minidns;

import android.content.Context;
import android.net.ConnectivityManager;
import android.net.LinkProperties;
import android.net.Network;
import android.os.Build;
import android.util.Log;
import androidx.annotation.NonNull;
import androidx.collection.LruCache;
import com.google.common.base.Strings;
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import eu.siacs.conversations.Config;
import java.io.IOException;
import java.net.InetAddress;
import java.time.Duration;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.minidns.AbstractDnsClient;
import org.minidns.dnsmessage.DnsMessage;
import org.minidns.dnsmessage.Question;
import org.minidns.dnsqueryresult.DnsQueryResult;
import org.minidns.dnsqueryresult.StandardDnsQueryResult;
import org.minidns.record.Data;
import org.minidns.record.Record;

/**
 * Asynchronous minidns client that resolves against the DNS configuration reported by Android's
 * {@link ConnectivityManager}.
 *
 * <p>Queries go to the DNS servers of the active network. If the active network lists none, the
 * servers of all networks are used. Servers are tried one after another until one answers with
 * {@code NO_ERROR} or {@code NX_DOMAIN}. Accepted responses are stored in a process-wide LRU cache
 * keyed by the (server, normalized question) pair. They are served from there until their TTL
 * (capped at {@link #DNS_MAX_TTL}) has elapsed.
 *
 * <p>The synchronous {@link #query(DnsMessage.Builder)} entry point inherited from {@link
 * AbstractDnsClient} is not supported; use {@link #queryAsFuture(Question)}.
 */
public class AndroidDNSClient extends AbstractDnsClient {

    /** Upper bound, in seconds, applied to record TTLs when computing cache expiry (one day). */
    private static final long DNS_MAX_TTL = 86_400L;

    /** Scheduler used to enforce the overall timeout of a query chain. */
    private static final ScheduledExecutorService SCHEDULED_EXECUTOR_SERVICE =
            Executors.newSingleThreadScheduledExecutor();

    /**
     * Response cache shared by all instances, keyed by (server, normalized question). Accesses
     * synchronize on the cache itself so that "look up, then evict if expired" is atomic.
     */
    private static final LruCache<QuestionServerTuple, DnsMessage> QUERY_CACHE =
            new LruCache<>(1024);

    private final Context context;
    private final NetworkDataSource networkDataSource = new NetworkDataSource();
    private boolean askForDnssec = false;

    public AndroidDNSClient(final Context context) {
        super();
        this.context = context;
    }

    /**
     * Returns the private DNS server name of the link on API 28+.
     *
     * @return the server name, or {@code null} on older releases (or if the link reports none)
     */
    private static String getPrivateDnsServerName(final LinkProperties linkProperties) {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
            return linkProperties.getPrivateDnsServerName();
        } else {
            return null;
        }
    }

    /**
     * Reports whether private DNS is active on the link.
     *
     * @return the link's flag on API 28+, always {@code false} on older releases
     */
    private static boolean isPrivateDnsActive(final LinkProperties linkProperties) {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
            return linkProperties.isPrivateDnsActive();
        } else {
            return false;
        }
    }

    /**
     * Applies this client's defaults to an outgoing message.
     *
     * <p>Sets "recursion desired" and configures EDNS: the UDP payload size comes from the network
     * data source, and the DNSSEC-OK bit comes from {@link #isAskForDnssec()}.
     */
    @Override
    protected DnsMessage.Builder newQuestion(final DnsMessage.Builder message) {
        message.setRecursionDesired(true);
        message.getEdnsBuilder()
                .setUdpPayloadSize(networkDataSource.getUdpPayloadSize())
                .setDnssecOk(askForDnssec);
        return message;
    }

    /**
     * Synchronous querying is not supported by this client.
     *
     * @throws IOException always
     */
    @Override
    protected DnsQueryResult query(final DnsMessage.Builder queryBuilder) throws IOException {
        throw new IOException("Not implemented");
    }

    /**
     * Resolves {@code q} asynchronously against the system's DNS servers.
     *
     * @param q the question to ask
     * @return a future completing with the first accepted result, or failing once every server
     *     has failed or the overall timeout expires
     */
    public ListenableFuture<DnsQueryResult> queryAsFuture(final Question q) {
        final DnsMessage.Builder query = buildMessage(q);
        return queryAsFuture(query);
    }

    /**
     * Builds the message from {@code queryBuilder} and sends it to the current DNS servers in
     * order, bounded by an overall timeout.
     */
    protected ListenableFuture<DnsQueryResult> queryAsFuture(
            final DnsMessage.Builder queryBuilder) {
        final var dnsServers = getDNSServers();
        final DnsMessage question = newQuestion(queryBuilder).build();
        // The queue is consumed by the fallback chain, so hand it a private mutable copy.
        final var rawFuture = queryAsFuture(question, new LinkedList<>(dnsServers));
        // allow for enough time to hit 2 servers over UDP and TCP one after another
        return Futures.withTimeout(
                rawFuture,
                Math.round(DNSSocket.QUERY_TIMEOUT * 4.2f),
                TimeUnit.MILLISECONDS,
                SCHEDULED_EXECUTOR_SERVICE);
    }

    /**
     * Asks the head of {@code dnsServers} and falls back to the remaining servers on failure.
     *
     * <p>A cached, unexpired response for (server, question) is returned without any network
     * access. A network response is accepted only with {@code NO_ERROR} or {@code NX_DOMAIN}. Any
     * other outcome (exception, null result, other response code) moves on to the next server.
     * When no servers remain, the last failure is propagated.
     *
     * @param question the fully built query message
     * @param dnsServers servers still to try; this method polls (mutates) the queue
     * @return a future for the first accepted result
     */
    protected ListenableFuture<DnsQueryResult> queryAsFuture(
            final DnsMessage question, final Queue<DNSServer> dnsServers) {
        if (dnsServers.isEmpty()) {
            return Futures.immediateFailedFuture(
                    new IllegalStateException("Tried all DNS servers"));
        }
        final var dnsServer = dnsServers.poll();
        if (dnsServer == null) {
            return Futures.immediateFailedFuture(new IllegalStateException("DNS Server was null"));
        }
        final QuestionServerTuple cacheKey = new QuestionServerTuple(dnsServer, question);
        final DnsMessage cachedResponse = queryCache(cacheKey);
        if (cachedResponse != null) {
            return Futures.immediateFuture(new CachedDnsQueryResult(question, cachedResponse));
        }
        final var future = this.networkDataSource.query(question, dnsServer);
        final var transformedFuture =
                Futures.transform(
                        future,
                        result -> {
                            if (result == null || result.response == null) {
                                throw new IllegalStateException("Result or response was null");
                            }
                            final var response = result.response;
                            if (response.responseCode == DnsMessage.RESPONSE_CODE.NO_ERROR
                                    || response.responseCode
                                            == DnsMessage.RESPONSE_CODE.NX_DOMAIN) {
                                // Cache here, while it is certain that *this* server answered.
                                // Caching after the fallback would file another server's
                                // answer under this server's key.
                                cacheQuery(cacheKey, response);
                                return new StandardDnsQueryResult(
                                        dnsServer.inetAddress,
                                        dnsServer.port,
                                        result.queryMethod,
                                        question,
                                        response);
                            }
                            throw new IllegalStateException("Received error response code");
                        },
                        MoreExecutors.directExecutor());

        // On any failure, try the next server; each recursive call caches its own answer.
        return Futures.catchingAsync(
                transformedFuture,
                Throwable.class,
                t -> {
                    if (dnsServers.isEmpty()) {
                        return Futures.immediateFailedFuture(t);
                    }
                    return queryAsFuture(question, dnsServers);
                },
                MoreExecutors.directExecutor());
    }

    /** Creates a message for {@code question} with a random id and this client's defaults. */
    final DnsMessage.Builder buildMessage(final Question question) {
        final DnsMessage.Builder message = DnsMessage.builder();
        message.setQuestion(question);
        message.setId(random.nextInt());
        return newQuestion(message);
    }

    /** Returns whether outgoing queries set the EDNS DNSSEC-OK bit. */
    public boolean isAskForDnssec() {
        return askForDnssec;
    }

    /** Controls whether outgoing queries set the EDNS DNSSEC-OK bit. */
    public void setAskForDnssec(boolean askForDnssec) {
        this.askForDnssec = askForDnssec;
    }

    /**
     * Collects DNS servers, preferring the active network.
     *
     * @return the active network's servers; all networks' servers if the active network has none;
     *     an empty list if no context or ConnectivityManager is available
     */
    private List<DNSServer> getDNSServers() {
        final var c = this.context;
        if (c == null) {
            Log.e(Config.LOGTAG, "no DNS servers found. Context not ready");
            return Collections.emptyList();
        }
        final ConnectivityManager connectivityManager =
                c.getSystemService(ConnectivityManager.class);
        if (connectivityManager == null) {
            Log.w(Config.LOGTAG, "no DNS servers found. ConnectivityManager was null");
            return Collections.emptyList();
        }
        final Network activeNetwork = connectivityManager.getActiveNetwork();
        final List<DNSServer> activeDnsServers =
                activeNetwork == null
                        ? Collections.emptyList()
                        : getDNSServers(connectivityManager, new Network[] {activeNetwork});
        if (activeDnsServers.isEmpty()) {
            Log.d(Config.LOGTAG, "no DNS servers on active networks. looking at all networks");
            return getDNSServers(connectivityManager, connectivityManager.getAllNetworks());
        } else {
            return activeDnsServers;
        }
    }

    /**
     * Extracts DNS servers from the link properties of {@code networks}, in network order.
     *
     * <p>If a link advertises a private DNS server name, only that name is used, over TLS.
     * Otherwise the link's DNS addresses are used: over TLS when private DNS is reported active,
     * and with the default transport otherwise. Networks without link properties are skipped.
     */
    private List<DNSServer> getDNSServers(
            @NonNull final ConnectivityManager connectivityManager,
            @NonNull final Network[] networks) {
        final ImmutableList.Builder<DNSServer> dnsServerBuilder = new ImmutableList.Builder<>();
        for (final Network network : networks) {
            final LinkProperties linkProperties = connectivityManager.getLinkProperties(network);
            if (linkProperties == null) {
                continue;
            }
            final String privateDnsServerName = getPrivateDnsServerName(linkProperties);
            if (Strings.isNullOrEmpty(privateDnsServerName)) {
                final boolean isPrivateDns = isPrivateDnsActive(linkProperties);
                for (final InetAddress dnsServer : linkProperties.getDnsServers()) {
                    if (isPrivateDns) {
                        dnsServerBuilder.add(new DNSServer(dnsServer, Transport.TLS));
                    } else {
                        dnsServerBuilder.add(new DNSServer(dnsServer));
                    }
                }
            } else {
                dnsServerBuilder.add(new DNSServer(privateDnsServerName, Transport.TLS));
            }
        }
        return dnsServerBuilder.build();
    }

    /**
     * Looks up a cached response, evicting it if it has expired.
     *
     * @return the cached response, or {@code null} if absent or expired
     */
    private DnsMessage queryCache(final QuestionServerTuple key) {
        final DnsMessage cachedResponse;
        synchronized (QUERY_CACHE) {
            cachedResponse = QUERY_CACHE.get(key);
            if (cachedResponse == null) {
                return null;
            }
            final long expiresIn = expiresIn(cachedResponse);
            if (expiresIn < 0) {
                QUERY_CACHE.remove(key);
                return null;
            }
            // Duration is only used for logging and is guarded to API 26+.
            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
                Log.d(
                        Config.LOGTAG,
                        "DNS query came from cache. expires in " + Duration.ofMillis(expiresIn));
            }
        }
        return cachedResponse;
    }

    /**
     * Stores {@code response} under {@code key}. Responses without a positive receive timestamp
     * are not cached, because their expiry cannot be computed.
     */
    private void cacheQuery(final QuestionServerTuple key, final DnsMessage response) {
        if (response.receiveTimestamp <= 0) {
            return;
        }
        synchronized (QUERY_CACHE) {
            QUERY_CACHE.put(key, response);
        }
    }

    /**
     * Returns the smallest record TTL (seconds) of the answer section.
     *
     * <p>If the answer section is empty, the smallest TTL of the authority section is used
     * instead. Returns 0 when both sections are empty, so such responses expire immediately.
     */
    private static long ttl(final DnsMessage dnsMessage) {
        final List<Record<? extends Data>> answerSection = dnsMessage.answerSection;
        if (answerSection == null || answerSection.isEmpty()) {
            final List<Record<? extends Data>> authoritySection = dnsMessage.authoritySection;
            if (authoritySection == null || authoritySection.isEmpty()) {
                return 0;
            } else {
                return Collections.min(Collections2.transform(authoritySection, d -> d.ttl));
            }
        } else {
            return Collections.min(Collections2.transform(answerSection, d -> d.ttl));
        }
    }

    /**
     * Returns the absolute expiry time in epoch milliseconds: receive timestamp plus the TTL,
     * capped at {@link #DNS_MAX_TTL}.
     */
    private static long expiresAt(final DnsMessage dnsMessage) {
        return dnsMessage.receiveTimestamp + (Math.min(DNS_MAX_TTL, ttl(dnsMessage)) * 1000L);
    }

    /** Returns the milliseconds until {@code dnsMessage} expires; negative once expired. */
    private static long expiresIn(final DnsMessage dnsMessage) {
        return expiresAt(dnsMessage) - System.currentTimeMillis();
    }

    /**
     * Cache key: a DNS server together with the normalized form of the question sent to it.
     *
     * <p>The question is normalized so that otherwise identical queries share one entry; the
     * question builder assigns a random message id to every query.
     */
    private record QuestionServerTuple(DNSServer dnsServer, DnsMessage question) {
        private QuestionServerTuple {
            question = question.asNormalizedVersion();
        }
    }

    /** Result served from {@link #QUERY_CACHE}; reports {@code QueryMethod.cachedDirect}. */
    public static class CachedDnsQueryResult extends DnsQueryResult {

        private CachedDnsQueryResult(final DnsMessage query, final DnsMessage response) {
            super(QueryMethod.cachedDirect, query, response);
        }
    }
}
