// Tests for dnss in GRPC modes.
package grpc

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"flag"
	"fmt"
	"io/ioutil"
	"math/big"
	"net"
	"os"
	"testing"
	"time"

	"blitiri.com.ar/go/dnss/internal/dnstox"
	"blitiri.com.ar/go/dnss/internal/grpctodns"
	"blitiri.com.ar/go/dnss/testing/util"

	"github.com/golang/glog"
	"github.com/miekg/dns"
)

// Addresses the test servers listen on: the DNS-to-GRPC proxy, the
// GRPC-to-DNS proxy, and the backing test DNS server. They are filled in at
// package initialization time by init() below.
var dnsToGrpcAddr, grpcToDnsAddr, dnsSrvAddr string

func init() {
	// Ask util for a port for each server, in declaration order.
	// NOTE(review): presumably GetFreePort returns a currently-unused local
	// "host:port"; confirm in testing/util.
	for _, addr := range []*string{&dnsToGrpcAddr, &grpcToDnsAddr, &dnsSrvAddr} {
		*addr = util.GetFreePort()
	}
}

//
// === Tests ===
//

// dnsQuery sends a fixed MX query for "ca.chai." over conn and waits for a
// reply. The reply contents are discarded; only errors from writing the
// query or reading the reply are returned.
func dnsQuery(conn *dns.Conn) error {
	m := &dns.Msg{}
	m.SetQuestion("ca.chai.", dns.TypeMX)

	// Report a failed send directly; otherwise it would only surface later
	// as a confusing error from ReadMsg.
	if err := conn.WriteMsg(m); err != nil {
		return err
	}
	_, err := conn.ReadMsg()
	return err
}

// TestSimple sends one query through the DNS-to-GRPC proxy and checks that
// a reply is read back without error.
func TestSimple(t *testing.T) {
	dialTimeout := time.Second
	conn, err := dns.DialTimeout("udp", dnsToGrpcAddr, dialTimeout)
	if err != nil {
		t.Fatalf("dns.Dial error: %v", err)
	}
	defer conn.Close()

	if qErr := dnsQuery(conn); qErr != nil {
		t.Errorf("dns query returned error: %v", qErr)
	}
}

//
// === Benchmarks ===
//

// manyDNSQueries dials addr over UDP and issues b.N sequential queries on
// that single connection. Query errors are reported but do not stop the
// loop; a dial error aborts the benchmark.
func manyDNSQueries(b *testing.B, addr string) {
	const dialTimeout = time.Second
	conn, err := dns.DialTimeout("udp", addr, dialTimeout)
	if err != nil {
		b.Fatalf("dns.Dial error: %v", err)
	}
	defer conn.Close()

	for remaining := b.N; remaining > 0; remaining-- {
		if qErr := dnsQuery(conn); qErr != nil {
			b.Errorf("dns query returned error: %v", qErr)
		}
	}
}

func BenchmarkGRPCDirect(b *testing.B) {
	manyDNSQueries(b, dnsSrvAddr)
}

func BenchmarkGRPCWithProxy(b *testing.B) {
	manyDNSQueries(b, dnsToGrpcAddr)
}

//
// === Test environment ===
//

// dnsServer is a minimal UDP DNS server used as the upstream in tests.
// It always gives the same reply, regardless of the query.
type dnsServer struct {
	// Addr is the address to listen on; set it before ListenAndServe.
	Addr string

	// srv is the underlying server, created by ListenAndServe.
	srv *dns.Server

	// answerRR is the single record included in every reply; it is
	// populated by ListenAndServe.
	answerRR dns.RR
}

// Handler answers every request with the canned s.answerRR. The reply echoes
// the request ID and is marked as an authoritative, successful response.
// Errors writing the reply are ignored.
func (s *dnsServer) Handler(w dns.ResponseWriter, r *dns.Msg) {
	// Building the reply (and setting the corresponding id) is cheaper than
	// copying a "master" message.
	reply := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:            r.Id,
			Response:      true,
			Authoritative: true,
			Rcode:         dns.RcodeSuccess,
		},
		Answer: []dns.RR{s.answerRR},
	}
	w.WriteMsg(reply)
}

// ListenAndServe prepares the canned answer record, then serves DNS over UDP
// on s.Addr using s.Handler. It blocks while serving and panics if the record
// cannot be parsed or the server fails.
func (s *dnsServer) ListenAndServe() {
	rr, err := dns.NewRR("test.blah A 1.2.3.4")
	s.answerRR = rr
	if err != nil {
		panic(err)
	}

	s.srv = &dns.Server{
		Addr:    s.Addr,
		Net:     "udp",
		Handler: dns.HandlerFunc(s.Handler),
	}
	if serveErr := s.srv.ListenAndServe(); serveErr != nil {
		panic(serveErr)
	}
}

// generateCert generates a new, INSECURE self-signed certificate and writes
// it and its private key to cert.pem and key.pem in the directory path.
// The certificate covers IP 127.0.0.1 and is valid for 30 minutes.
// It is only useful for testing purposes.
//
// It returns an error if key generation, certificate creation, or writing
// either file fails, including errors reported when closing the files.
func generateCert(path string) error {
	now := time.Now()
	tmpl := x509.Certificate{
		SerialNumber: big.NewInt(1234),
		Subject: pkix.Name{
			Organization: []string{"dnss testing"},
		},

		IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},

		NotBefore: now,
		NotAfter:  now.Add(30 * time.Minute),

		KeyUsage: x509.KeyUsageKeyEncipherment |
			x509.KeyUsageDigitalSignature |
			x509.KeyUsageCertSign,

		// The certificate is self-signed, so it acts as its own CA.
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// A small key keeps generation quick; this key is for tests only.
	priv, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		return err
	}

	derBytes, err := x509.CreateCertificate(
		rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
	if err != nil {
		return err
	}

	err = writePEMFile(path+"/cert.pem", 0666,
		&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	if err != nil {
		return err
	}

	// The private key is written readable by the owner only.
	return writePEMFile(path+"/key.pem", 0600, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(priv),
	})
}

// writePEMFile PEM-encodes block into filename, creating or truncating the
// file with permissions perm. Errors from both the encode and the final
// Close are returned, so a failed write is not silently lost.
func writePEMFile(filename string, perm os.FileMode, block *pem.Block) error {
	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	if err := pem.Encode(f, block); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}

// realMain is the real main function, which returns the value to pass to
// os.Exit(). We have to do this so we can use defer.
//
// It writes a throwaway certificate into a temporary directory and starts
// three servers in background goroutines: the DNS-to-GRPC proxy, the
// GRPC-to-DNS proxy, and the backing test DNS server. It then waits, via
// util.WaitForDNSServer, on the DNS-to-GRPC address before running the
// tests. It returns 1 if any of the setup steps fail.
func realMain(m *testing.M) int {
	flag.Parse()
	defer glog.Flush()

	// Generate certificates in a temporary directory.
	tmpDir, err := ioutil.TempDir("", "dnss_test:")
	if err != nil {
		// Report the error itself; tmpDir is empty when TempDir fails.
		fmt.Printf("Failed to create temp dir: %v\n", err)
		return 1
	}
	defer os.RemoveAll(tmpDir)

	err = generateCert(tmpDir)
	if err != nil {
		fmt.Printf("Failed to generate cert for testing: %v\n", err)
		return 1
	}
	certFile := tmpDir + "/cert.pem"
	keyFile := tmpDir + "/key.pem"

	// DNS to GRPC server.
	gr := dnstox.NewGRPCResolver(grpcToDnsAddr, certFile)
	cr := dnstox.NewCachingResolver(gr)
	dtg := dnstox.New(dnsToGrpcAddr, cr, "")
	go dtg.ListenAndServe()

	// GRPC to DNS server.
	gtd := &grpctodns.Server{
		Addr:     grpcToDnsAddr,
		Upstream: dnsSrvAddr,
		CertFile: certFile,
		KeyFile:  keyFile,
	}
	go gtd.ListenAndServe()

	// DNS test server.
	dnsSrv := dnsServer{
		Addr: dnsSrvAddr,
	}
	go dnsSrv.ListenAndServe()

	// Wait for the servers to start up.
	err = util.WaitForDNSServer(dnsToGrpcAddr)
	if err != nil {
		fmt.Printf("Error waiting for the test servers to start: %v\n", err)
		fmt.Printf("Check the INFO logs for more details\n")
		return 1
	}

	return m.Run()
}

// TestMain runs the full test setup and the tests via realMain, then exits
// the process with the status code realMain returns.
func TestMain(m *testing.M) {
	exitCode := realMain(m)
	os.Exit(exitCode)
}
