//go:build daita
// +build daita

package device

import (
	"encoding/binary"
	"sync"
	"time"
	"unsafe"
)

// #include <stdio.h>
// #include <stdlib.h>
// #include "../maybenot-ffi/maybenot.h"
// #cgo LDFLAGS: -L${SRCDIR}/../ -lmaybenot -lm
import "C"

// MaybenotDaita is the per-peer DAITA state. It wraps a maybenot framework
// instance, the event pipeline that feeds it (events -> runEventLoop ->
// handleEvents), and the timers that carry out the actions maybenot returns.
type MaybenotDaita struct {
	// Buffered queue of events waiting to be fed to maybenot; closed by Close.
	events          chan Event
	// Set once events has been closed. Guarded by eventsCloseLock so that
	// senders never send on a closed channel.
	eventsClosed    bool
	eventsCloseLock sync.RWMutex
	// Preallocated C buffer that an event batch is copied into; its length
	// bounds the batch size built by runEventLoop.
	eventsCBuf      []C.MaybenotEvent
	// NOTE(review): not referenced in this file; confirm whether it is still used.
	actions         chan Action
	// Handle obtained from maybenot_start; released with maybenot_stop when the
	// event loop exits.
	maybenot        *C.MaybenotFramework
	// Output buffer for maybenot_on_events, sized to the number of machines.
	newActionsBuf   []C.MaybenotAction
	paddingQueue    map[uint64]*time.Timer   // Map from machine to queued padding packets
	machineTimers   map[uint64]*MachineTimer // Map from machine to machine timer
	logger          *Logger
	stopping        sync.WaitGroup // waitgroup for runEventLoop and the pending padding/machine timer callbacks
}

// MachineTimer is a maybenot machine's internal timer: a time.Timer paired
// with the instant it is expected to fire, so callers can work out how much
// time is left on it.
type MachineTimer struct {
	completeAt time.Time
	timer      *time.Timer
}

// Stop cancels the underlying timer. It reports true only if this call
// prevented the callback from running, i.e. the timer had neither fired nor
// already been stopped.
func (mt *MachineTimer) Stop() bool {
	inner := mt.timer
	return inner.Stop()
}

// newMachineTimer schedules callback to run after timeout and records the
// instant at which that is expected to happen.
func newMachineTimer(timeout time.Duration, callback func()) *MachineTimer {
	deadline := time.Now().Add(timeout)
	scheduled := time.AfterFunc(timeout, callback)
	return &MachineTimer{timer: scheduled, completeAt: deadline}
}

// Event is a single maybenot event queued for the DAITA event loop.
type Event struct {
	// The machine that generated the action that generated this event, if any.
	// Events not caused by a machine's action carry 0 here.
	Machine   uint64
	// The kind of event; passed to maybenot as a uint32 event_type.
	EventType EventType
}

// Error codes. NOTE(review): these are not referenced in this file; presumably
// they are returned by an exported API elsewhere in the package — confirm their
// exact meaning there.
const (
	ERROR_GENERAL_FAILURE      = -1
	ERROR_INTERMITTENT_FAILURE = -2
)

// Action is a flattened Go copy of a C.MaybenotAction tagged union. It is
// produced by cActionToGo and carried out by handleEvents.
//
// TODO: Consider using an Action interface, and defining MaybenotDaita.handleEvents as a method on
// that interface. Each Action enum variant could be an implementation on that interface, that way
// we wouldn't have to flatten the Action enum into this ugly struct.
// Performance may be a concern though.
type Action struct {
	ActionType C.MaybenotAction_Tag

	// The maybenot machine that generated the action.
	// Should be propagated back by events generated by this action.
	Machine uint64

	// -- The fields below may or may not be used depending on ActionType --

	// Which of the machine's timers to cancel.
	// Used for ActionTypes: Cancel
	Timer C.MaybenotTimer

	// The delay after which the action should be performed (it is passed to
	// time.AfterFunc for SendPadding).
	// Used for ActionTypes: SendPadding, BlockOutgoing
	Timeout time.Duration

	// For UpdateTimer: how long the machine timer should run.
	// Used for ActionTypes: BlockOutgoing, UpdateTimer
	Duration time.Duration

	// Used for ActionTypes: SendPadding, BlockOutgoing, UpdateTimer
	Replace bool

	// Used for ActionType: SendPadding, BlockOutgoing
	Bypass bool
}

// EnableDaita starts DAITA (maybenot-driven padding) for the peer.
//
// machines, maxPaddingFrac and maxBlockingFrac are passed to maybenot_start.
// eventsCapacity sizes both the buffered events channel and the C event batch
// buffer, and must be greater than zero. actionsCapacity is currently unused;
// the action buffer is sized from the number of machines instead.
//
// It returns false if the peer is not running, DAITA is already active,
// eventsCapacity is zero, or maybenot fails to start. Otherwise it starts the
// event loop goroutine, stores the new state in peer.daita, and returns true.
func (peer *Peer) EnableDaita(machines string, eventsCapacity uint, actionsCapacity uint, maxPaddingFrac float64, maxBlockingFrac float64) bool {
	peer.Lock()
	defer peer.Unlock()

	if !peer.isRunning.Load() {
		return false
	}

	if peer.daita != nil {
		peer.device.log.Errorf("Failed to activate DAITA as it is already active")
		return false
	}

	// A zero capacity would leave eventsCBuf empty, and the event loop would
	// index past its end (and panic) on the very first event. Reject it before
	// maybenot is started so there is nothing to tear down.
	if eventsCapacity == 0 {
		peer.device.log.Errorf("Failed to activate DAITA: events capacity must be greater than zero")
		return false
	}

	peer.device.log.Verbosef("Enabling DAITA for peer: %v", peer)

	mtu := peer.device.tun.mtu.Load()

	peer.device.log.Verbosef("MTU %v", mtu)
	var maybenot *C.MaybenotFramework
	c_machines := C.CString(machines)

	c_maxPaddingFrac := C.double(maxPaddingFrac)
	c_maxBlockingFrac := C.double(maxBlockingFrac)

	maybenot_result := C.maybenot_start(
		c_machines, c_maxPaddingFrac, c_maxBlockingFrac,
		&maybenot,
	)
	C.free(unsafe.Pointer(c_machines))

	if maybenot_result != 0 {
		peer.device.log.Errorf("Failed to initialize maybenot, code=%d", maybenot_result)
		return false
	}

	// maybenot_on_events writes its actions into newActionsBuf, which is sized
	// to the number of machines.
	numMachines := C.maybenot_num_machines(maybenot)
	daita := MaybenotDaita{
		events:        make(chan Event, eventsCapacity),
		eventsCBuf:    make([]C.MaybenotEvent, eventsCapacity),
		eventsClosed:  false,
		maybenot:      maybenot,
		newActionsBuf: make([]C.MaybenotAction, numMachines),
		paddingQueue:  map[uint64]*time.Timer{},
		machineTimers: map[uint64]*MachineTimer{},
		logger:        peer.device.log,
	}

	// The event loop holds one count on stopping until it exits.
	daita.stopping.Add(1)
	go daita.runEventLoop(peer)

	peer.daita = &daita

	return true
}

// Close stops the MaybenotDaita instance and blocks until the event loop and
// every timer callback it scheduled have finished. The instance must not be
// used after calling this; calling Close more than once is harmless.
//
// Close only closes the events channel and waits. Pending padding and machine
// timers are cancelled by runEventLoop on its way out, because that goroutine
// is the only one that reads and writes paddingQueue and machineTimers.
// Ranging over those maps here would race with an in-progress handleEvents,
// and would miss timers scheduled from events still buffered in the channel.
func (daita *MaybenotDaita) Close() {
	daita.logger.Verbosef("Waiting for DAITA routines to stop")

	daita.eventsCloseLock.Lock()
	if !daita.eventsClosed {
		close(daita.events)
		daita.eventsClosed = true
	}
	daita.eventsCloseLock.Unlock()

	daita.stopping.Wait()
	daita.logger.Verbosef("DAITA routines have stopped")
}

// NormalReceived queues a NormalReceived event for the event loop.
func (daita *MaybenotDaita) NormalReceived(peer *Peer) {
	daita.event(peer, NormalReceived, 0)
}

// PaddingReceived queues a PaddingReceived event for the event loop.
func (daita *MaybenotDaita) PaddingReceived(peer *Peer) {
	daita.event(peer, PaddingReceived, 0)
}

// PaddingSent queues a PaddingSent event attributed to machine.
func (daita *MaybenotDaita) PaddingSent(peer *Peer, machine uint64) {
	daita.event(peer, PaddingSent, machine)
}

// NormalSent queues a NormalSent event for the event loop.
func (daita *MaybenotDaita) NormalSent(peer *Peer) {
	daita.event(peer, NormalSent, 0)
}

// TunnelSent queues a TunnelSent event for the event loop.
func (daita *MaybenotDaita) TunnelSent(peer *Peer) {
	daita.event(peer, TunnelSent, 0)
}

// TunnelReceived queues a TunnelReceived event for the event loop.
func (daita *MaybenotDaita) TunnelReceived(peer *Peer) {
	daita.event(peer, TunnelReceived, 0)
}

// timerBegin queues a TimerBegin event for machine.
func (daita *MaybenotDaita) timerBegin(peer *Peer, machine uint64) {
	daita.event(peer, TimerBegin, machine)
}

// timerEnd queues a TimerEnd event for machine.
func (daita *MaybenotDaita) timerEnd(peer *Peer, machine uint64) {
	daita.event(peer, TimerEnd, machine)
}

// event queues an event for the event loop without blocking. machine is the
// machine whose action caused the event (0 when none). It is a no-op on a nil
// receiver or once Close has closed the channel, and the event is dropped (and
// logged) when the buffer is full.
func (daita *MaybenotDaita) event(peer *Peer, eventType EventType, machine uint64) {
	if daita == nil {
		return
	}

	event := Event{
		Machine:   machine,
		EventType: eventType,
	}

	// Holding the read lock across the send keeps Close from closing the
	// channel between the eventsClosed check and the send.
	daita.eventsCloseLock.RLock()
	defer daita.eventsCloseLock.RUnlock()

	if daita.eventsClosed {
		return
	}

	select {
	case daita.events <- event:
	default:
		peer.device.log.Verbosef("Dropped DAITA event %v due to full buffer", event.EventType)
	}
}

// injectPadding carries out a SendPadding action for peer. It runs from the
// padding timer scheduled in handleEvents.
//
// If the action allows replacement and the peer has a replaceable packet, that
// packet is counted as padding instead of sending a new one. Otherwise, if the
// peer is still running, an MTU-sized padding packet is staged and sent. A
// PaddingSent event is reported to the requesting machine in both cases.
func injectPadding(action Action, peer *Peer) {
	if action.ActionType != C.MaybenotAction_SendPadding {
		peer.device.log.Errorf("Got unknown action type %v", action.ActionType)
		return
	}

	if action.Replace && peer.HasReplaceablePackets() {
		peer.ReplacedPacketsInc()
		peer.daita.PaddingSent(peer, action.Machine)
		return
	}

	// Check liveness before obtaining an outbound element, rather than
	// creating one only to drop it unstaged when the peer has stopped.
	if !peer.isRunning.Load() {
		return
	}

	elem := peer.device.NewOutboundElement()
	elem.daitaPadding = true

	// All packets are MTU-sized when DAITA is enabled
	size := uint16(peer.device.tun.mtu.Load())

	elem.packet = elem.buffer[MessageTransportHeaderSize : MessageTransportHeaderSize+int(size)]
	elem.packet[0] = DaitaPaddingMarker
	binary.BigEndian.PutUint16(elem.packet[DaitaOffsetTotalLength:DaitaOffsetTotalLength+2], size)

	peer.StagePacket(elem)
	peer.SendStagedPackets()

	peer.daita.PaddingSent(peer, action.Machine)
}

// runEventLoop is the single consumer of daita.events. It batches whatever
// events are immediately available (at most len(daita.eventsCBuf) per batch)
// and hands each batch to handleEvents, until the channel is closed.
//
// Because handleEvents only runs on this goroutine, this goroutine owns
// paddingQueue and machineTimers. On exit it therefore cancels every pending
// timer itself, then releases the maybenot framework and its own WaitGroup
// count.
func (daita *MaybenotDaita) runEventLoop(peer *Peer) {
	defer func() {
		// No new timers can be scheduled past this point. Each successful Stop
		// releases the count that timer's callback would have released; timers
		// that already fired release their own.
		for _, timer := range daita.machineTimers {
			if timer.Stop() {
				daita.stopping.Done()
			}
		}
		for _, queuedPadding := range daita.paddingQueue {
			if queuedPadding.Stop() {
				daita.stopping.Done()
			}
		}

		C.maybenot_stop(daita.maybenot)
		daita.stopping.Done()
		daita.logger.Verbosef("%v - DAITA: event handler - stopped", peer)
	}()

	// Reused for every batch; a batch never exceeds len(daita.eventsCBuf).
	events := make([]Event, 0, len(daita.eventsCBuf))

	for {
		events = events[:0]

		event, more := <-daita.events
		if !more {
			return
		}

		events = append(events, event)

		// Drain remaining events
	HandleEvents:
		for {
			select {
			case event, more := <-daita.events:
				if !more {
					daita.logger.Verbosef("%v - DAITA: dropping %d unhandled events", peer, len(events))
					return
				}
				events = append(events, event)

				// Make sure not to exceed the maximum C buffer capacity
				if len(events) >= len(daita.eventsCBuf) {
					break HandleEvents
				}
			default:
				break HandleEvents
			}
		}

		daita.handleEvents(events, peer)
	}
}

// handleEvents feeds a batch of events to maybenot and carries out the
// resulting actions. It is only called from runEventLoop.
//
// WaitGroup accounting: every scheduled padding or machine timer holds one
// count on daita.stopping. The count is released by the timer's callback when
// it fires, or by whoever successfully Stop()s the timer before it fires.
func (daita *MaybenotDaita) handleEvents(events []Event, peer *Peer) {
	for _, cAction := range daita.maybenotEventToActions(events) {
		action := cActionToGo(cAction)

		switch action.ActionType {
		case C.MaybenotAction_Cancel:
			switch action.Timer {
			case C.MaybenotTimer_Action:
				daita.stopPaddingTimer(action.Machine)
			case C.MaybenotTimer_Internal:
				daita.stopMachineTimer(action.Machine)
			case C.MaybenotTimer_All:
				daita.stopMachineTimer(action.Machine)
				daita.stopPaddingTimer(action.Machine)
			}

		case C.MaybenotAction_SendPadding:
			// A machine has at most one scheduled padding action; a new one
			// replaces whatever is still pending.
			queued, paddingWasQueued := daita.paddingQueue[action.Machine]
			// If Stop succeeds, the old timer's count is handed over to the new
			// timer. Otherwise (nothing queued, or it already fired and will
			// release its own count), the new timer needs a fresh count.
			if !paddingWasQueued || !queued.Stop() {
				daita.stopping.Add(1)
			}

			daita.paddingQueue[action.Machine] = time.AfterFunc(action.Timeout, func() {
				defer daita.stopping.Done()
				injectPadding(action, peer)
			})

		case C.MaybenotAction_BlockOutgoing:
			// TODO: implement BlockOutgoing
			daita.logger.Errorf("ignoring BlockOutgoing action, unimplemented")

		case C.MaybenotAction_UpdateTimer:
			// Look up the machine's current internal timer, if any.
			timer, timerWasQueued := daita.machineTimers[action.Machine]

			// Start a new timer if none exists, if the replace flag is set, or
			// if the requested duration outlasts the time left on the current
			// one. The time left is negative if the timer should already have
			// fired.
			startNewTimer := !timerWasQueued || action.Replace ||
				action.Duration > time.Until(timer.completeAt)

			if startNewTimer {
				// Same count hand-over as for padding timers above.
				if !timerWasQueued || !timer.Stop() {
					daita.stopping.Add(1)
				}

				daita.timerBegin(peer, action.Machine)
				daita.machineTimers[action.Machine] = newMachineTimer(action.Duration, func() {
					defer daita.stopping.Done()
					daita.timerEnd(peer, action.Machine)
				})
			}
		}
	}
}

// stopMachineTimer cancels machine's internal timer, if one exists, and
// removes it from machineTimers.
//
// The entry is removed so that a cancelled timer no longer counts as running.
// Otherwise a later UpdateTimer without the replace flag would compare against
// the cancelled timer's stale completeAt and could decline to start a new
// timer, leaving the machine with no timer at all.
//
// A successful Stop releases the WaitGroup count the callback would have
// released; a timer that already fired releases its own count.
func (daita *MaybenotDaita) stopMachineTimer(machine uint64) {
	timer, ok := daita.machineTimers[machine]
	if !ok {
		return
	}
	if timer.Stop() {
		daita.stopping.Done()
	}
	delete(daita.machineTimers, machine)
}

// stopPaddingTimer cancels the padding queued for machine, if any, and
// removes it from paddingQueue, mirroring stopMachineTimer.
//
// A successful Stop releases the WaitGroup count the callback would have
// released; a timer that already fired releases its own count.
func (daita *MaybenotDaita) stopPaddingTimer(machine uint64) {
	queuedPadding, ok := daita.paddingQueue[machine]
	if !ok {
		return
	}
	if queuedPadding.Stop() {
		daita.stopping.Done()
	}
	delete(daita.paddingQueue, machine)
}

// maybenotEventToActions feeds a batch of events to the maybenot framework and
// returns the actions it produced.
//
// Events are copied into the preallocated C buffer daita.eventsCBuf. Any
// events beyond its length are dropped and logged rather than indexing out of
// range. The returned slice aliases daita.newActionsBuf and is only valid
// until the next call. If maybenot reports an error, it is logged and nil is
// returned.
func (daita *MaybenotDaita) maybenotEventToActions(events []Event) []C.MaybenotAction {
	if len(events) > len(daita.eventsCBuf) {
		daita.logger.Errorf("DAITA: dropping %d events exceeding the event buffer capacity", len(events)-len(daita.eventsCBuf))
		events = events[:len(daita.eventsCBuf)]
	}

	// Nothing to feed, or no machines that could produce actions. This also
	// avoids taking the address of the first element of an empty buffer.
	if len(events) == 0 || len(daita.newActionsBuf) == 0 {
		return nil
	}

	for i, event := range events {
		daita.eventsCBuf[i] = C.MaybenotEvent{
			machine:    C.uintptr_t(event.Machine),
			event_type: C.uint32_t(event.EventType),
		}
	}

	var actionsWritten C.uintptr_t
	result := C.maybenot_on_events(
		daita.maybenot,
		unsafe.SliceData(daita.eventsCBuf),
		C.uintptr_t(len(events)),
		unsafe.SliceData(daita.newActionsBuf),
		&actionsWritten,
	)
	if result != 0 {
		daita.logger.Errorf("Failed to handle DAITA events, code=%d", result)
		return nil
	}

	return daita.newActionsBuf[:actionsWritten]
}

// cActionToGo converts a tagged C.MaybenotAction union into the flattened Go
// Action. Only the fields relevant to the action's tag are filled in; the rest
// keep their zero values. It panics on a tag it does not recognise.
func cActionToGo(action_c C.MaybenotAction) Action {
	// Every variant's body is stored at the start of the union's storage.
	payload := unsafe.Pointer(&action_c.anon0[0])
	converted := Action{ActionType: action_c.tag}

	switch action_c.tag {
	case C.MaybenotAction_Cancel:
		cancel := (*C.MaybenotAction_Cancel_Body)(payload)
		converted.Machine = uint64(cancel.machine)
		converted.Timer = cancel.timer

	case C.MaybenotAction_SendPadding:
		padding := (*C.MaybenotAction_SendPadding_Body)(payload)
		converted.Machine = uint64(padding.machine)
		converted.Timeout = maybenotDurationToGoDuration(padding.timeout)
		converted.Replace = bool(padding.replace)
		converted.Bypass = bool(padding.bypass)

	case C.MaybenotAction_BlockOutgoing:
		block := (*C.MaybenotAction_BlockOutgoing_Body)(payload)
		converted.Machine = uint64(block.machine)
		converted.Timeout = maybenotDurationToGoDuration(block.timeout)
		converted.Replace = bool(block.replace)
		converted.Bypass = bool(block.bypass)
		converted.Duration = maybenotDurationToGoDuration(block.duration)

	case C.MaybenotAction_UpdateTimer:
		update := (*C.MaybenotAction_UpdateTimer_Body)(payload)
		converted.Machine = uint64(update.machine)
		converted.Duration = maybenotDurationToGoDuration(update.duration)
		converted.Replace = bool(update.replace)

	default:
		panic("Unknown C.MaybenotAction tag")
	}

	return converted
}

// maybenotDurationToGoDuration converts a maybenot (seconds, nanoseconds)
// duration into a time.Duration.
//
// Values too large for time.Duration (roughly 292 years) saturate to the
// largest representable duration. The previous arithmetic wrapped around,
// producing a negative or tiny duration that made timers fire immediately
// instead of effectively never.
func maybenotDurationToGoDuration(duration C.MaybenotDuration) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)
	const maxSecs = uint64(maxDuration / time.Second)

	secs := uint64(duration.secs)
	nanos := uint64(duration.nanos)

	// Guard the multiplication: secs*1e9 overflows uint64 for large secs.
	if secs > maxSecs {
		return maxDuration
	}

	// secs <= maxSecs keeps secs*1e9 just below 2^63, so adding nanos cannot
	// overflow uint64; it may still exceed the int64 range.
	total := secs*uint64(time.Second) + nanos
	if total > uint64(maxDuration) {
		return maxDuration
	}
	return time.Duration(total)
}
