// Copyright 2016-2024 Mullvad VPN AB. All Rights Reserved.
// Copyright 2024 Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: GPL-3.0-only

use crate::RequiredRoute;
pub use default_route_monitor::EventType;
use futures::{
    StreamExt,
    channel::{
        mpsc::{self, UnboundedReceiver, UnboundedSender},
        oneshot,
    },
};
pub use get_best_default_route::{InterfaceAndGateway, get_best_default_route};
use net::AddressFamily;
use nym_common::trace_err_chain;
use nym_windows::net;
pub use route_manager::{Callback, CallbackHandle, Route, RouteManagerInternal};
use std::{collections::HashSet, io, net::IpAddr};

mod default_route_monitor;
mod get_best_default_route;
mod route_manager;

/// Windows routing errors.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Failure to initialize route manager
    #[error("failed to start route manager")]
    FailedToStartManager,
    /// Attempt to use route manager that has been dropped
    #[error("cannot send message to route manager since it is down")]
    RouteManagerDown,
    /// Low level error caused by a failure to add to route table
    #[error("could not add route to route table")]
    AddToRouteTable(#[source] windows::core::Error),
    /// Low level error caused by failure to delete route from route table
    #[error("failed to delete applied routes")]
    DeleteFromRouteTable(#[source] windows::core::Error),
    /// GetIpForwardTable2 windows API call failed
    #[error("failed to retrieve the routing table")]
    GetIpForwardTableFailed(#[source] windows::core::Error),
    /// GetIfEntry2 windows API call failed
    #[error("failed to retrieve network interface entry")]
    GetIfEntryFailed(#[source] windows::core::Error),
    /// Low level error caused by failing to register the route callback
    #[error("attempt to register notify route change callback failed")]
    RegisterNotifyRouteCallback(#[source] windows::core::Error),
    /// Low level error caused by failing to register the ip interface callback
    #[error("attempt to register notify ip interface change callback failed")]
    RegisterNotifyIpInterfaceCallback(#[source] windows::core::Error),
    /// Low level error caused by failing to register the unicast ip address callback
    #[error("attempt to register notify unicast ip address change callback failed")]
    RegisterNotifyUnicastIpAddressCallback(#[source] windows::core::Error),
    /// Low level error caused by windows Adapters API
    #[error("windows adapter error")]
    Adapter(io::Error),
    /// High level error caused by a failure to clear the routes in the route manager.
    /// Contains the lower error
    #[error("failed to clear applied routes")]
    ClearRoutesFailed(Box<Error>),
    /// High level error caused by a failure to add routes in the route manager.
    /// Contains the lower error
    #[error("failed to add routes")]
    AddRoutesFailed(Box<Error>),
    /// Something went wrong when getting the mtu of the interface
    #[error("could not get the mtu of the interface")]
    GetMtu,
    /// The SI family was of an unexpected value
    #[error("the SI family was of an unexpected value")]
    InvalidSiFamily,
    /// Device name not found
    #[error("the device name was not found")]
    DeviceNameNotFound,
    /// No default route
    #[error("no default route found")]
    NoDefaultRoute,
    /// Conversion error between types
    #[error("conversion error")]
    Conversion,
    /// Could not find device gateway
    #[error("could not find device gateway")]
    DeviceGatewayNotFound,
    /// Could not get default route
    #[error("could not get default route")]
    GetDefaultRoute,
    /// Could not find device by name
    #[error("could not find device by name")]
    GetDeviceByName,
    /// Could not find device by gateway
    #[error("could not find device by gateway")]
    GetDeviceByGateway,
}

impl Error {
    /// Return whether retrying the operation that caused this error is likely to succeed.
    pub fn is_recoverable(&self) -> bool {
        matches!(self, Error::AddRoutesFailed(_))
    }
}

/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

/// Cloneable handle for talking to the route manager task.
///
/// Each method wraps a request in a [`RouteManagerCommand`] and pushes it onto an
/// unbounded channel; the task started by [`RouteManagerHandle::spawn`] receives the
/// commands and applies them to a [`RouteManagerInternal`].
#[derive(Debug, Clone)]
pub struct RouteManagerHandle {
    /// Sending half of the command channel read by the route manager task.
    tx: UnboundedSender<RouteManagerCommand>,
}

/// Commands sent from a [`RouteManagerHandle`] to the route manager task.
///
/// Variants carrying a `oneshot::Sender` use it to deliver the reply to the caller.
pub enum RouteManagerCommand {
    /// Apply the given set of routes; the outcome is sent back on the channel.
    AddRoutes(HashSet<RequiredRoute>, oneshot::Sender<Result<()>>),
    /// Look up an MTU and send it back on the channel. The task only uses the address
    /// family (IPv4 or IPv6) of the given IP to pick the best default route.
    GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16>>),
    /// Delete the applied routes. There is no reply; the task only logs a failure.
    ClearRoutes,
    /// Register a default-route-change callback; the resulting handle is sent back.
    RegisterDefaultRouteChangeCallback(Callback, oneshot::Sender<CallbackHandle>),
    /// Drop the internal route manager and stop the task; the channel is signalled
    /// after the drop.
    Shutdown(oneshot::Sender<()>),
}

impl RouteManagerHandle {
    /// Create a new route manager
    #[allow(clippy::unused_async)]
    pub async fn spawn() -> Result<Self> {
        let internal = RouteManagerInternal::new().map_err(|_| Error::FailedToStartManager)?;
        let (tx, rx) = mpsc::unbounded();
        let handle = Self { tx };
        tokio::spawn(RouteManagerHandle::run(rx, internal));

        Ok(handle)
    }

    /// Add a callback which will be called if the default route changes.
    pub async fn add_default_route_change_callback(
        &self,
        callback: Callback,
    ) -> Result<CallbackHandle> {
        let (response_tx, response_rx) = oneshot::channel();
        self.tx
            .unbounded_send(RouteManagerCommand::RegisterDefaultRouteChangeCallback(
                callback,
                response_tx,
            ))
            .map_err(|_| Error::RouteManagerDown)?;
        response_rx.await.map_err(|_| Error::RouteManagerDown)
    }

    /// Applies the given routes while the route manager is running.
    pub async fn add_routes(&self, routes: HashSet<RequiredRoute>) -> Result<()> {
        let (response_tx, response_rx) = oneshot::channel();
        self.tx
            .unbounded_send(RouteManagerCommand::AddRoutes(routes, response_tx))
            .map_err(|_| Error::RouteManagerDown)?;
        response_rx.await.map_err(|_| Error::RouteManagerDown)?
    }

    /// Retrieve MTU for the given destination/route.
    pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> {
        let (response_tx, response_rx) = oneshot::channel();
        self.tx
            .unbounded_send(RouteManagerCommand::GetMtuForRoute(ip, response_tx))
            .map_err(|_| Error::RouteManagerDown)?;
        response_rx.await.map_err(|_| Error::RouteManagerDown)?
    }

    /// Stop the routing manager actor and revert all changes to routing
    pub async fn stop(&self) {
        let (result_tx, result_rx) = oneshot::channel();
        _ = self
            .tx
            .unbounded_send(RouteManagerCommand::Shutdown(result_tx));
        _ = result_rx.await;
    }

    /// Removes all routes previously applied in [`RouteManagerInternal::add_routes`].
    pub fn clear_routes(&self) -> Result<()> {
        self.tx
            .unbounded_send(RouteManagerCommand::ClearRoutes)
            .map_err(|_| Error::RouteManagerDown)
    }

    async fn run(
        mut manage_rx: UnboundedReceiver<RouteManagerCommand>,
        mut internal: RouteManagerInternal,
    ) {
        while let Some(command) = manage_rx.next().await {
            match command {
                RouteManagerCommand::AddRoutes(routes, tx) => {
                    let routes: Vec<_> = routes
                        .into_iter()
                        .map(|route| Route {
                            network: route.prefix,
                            node: route.node,
                        })
                        .collect();

                    let _ = tx.send(
                        internal
                            .add_routes(routes)
                            .map_err(|e| Error::AddRoutesFailed(Box::new(e))),
                    );
                }
                RouteManagerCommand::GetMtuForRoute(ip, tx) => {
                    let addr_family = if ip.is_ipv4() {
                        AddressFamily::Ipv4
                    } else {
                        AddressFamily::Ipv6
                    };
                    let res = match get_mtu_for_route(addr_family) {
                        Ok(Some(mtu)) => Ok(mtu),
                        Ok(None) => Err(Error::GetMtu),
                        Err(e) => Err(e),
                    };
                    let _ = tx.send(res);
                }
                RouteManagerCommand::ClearRoutes => {
                    if let Err(e) = internal.delete_applied_routes() {
                        trace_err_chain!(e, "Could not clear routes");
                    }
                }
                RouteManagerCommand::RegisterDefaultRouteChangeCallback(callback, tx) => {
                    let _ = tx.send(internal.register_default_route_changed_callback(callback));
                }
                RouteManagerCommand::Shutdown(tx) => {
                    drop(internal);
                    let _ = tx.send(());
                    break;
                }
            }
        }
    }
}

fn get_mtu_for_route(addr_family: AddressFamily) -> Result<Option<u16>> {
    match get_best_default_route(addr_family) {
        Ok(Some(route)) => {
            let interface_row =
                net::get_ip_interface_entry(addr_family, &route.iface).map_err(|e| {
                    tracing::error!("Could not get ip interface entry: {}", e);
                    Error::GetMtu
                })?;
            let mtu = interface_row.NlMtu;
            let mtu = u16::try_from(mtu).map_err(|_| Error::GetMtu)?;
            Ok(Some(mtu))
        }
        Ok(None) => Ok(None),
        Err(e) => {
            tracing::error!("Could not get best default route: {}", e);
            Err(Error::GetMtu)
        }
    }
}
