/*
    This file is part of NetGuard.

    NetGuard is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    NetGuard is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with NetGuard.  If not, see <http://www.gnu.org/licenses/>.

    Copyright 2015-2024 by Marcel Bokhorst (M66B)
*/

#include <jni.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <time.h>
#include <unistd.h>
#include <pthread.h>
#include <setjmp.h>
#include <errno.h>
#include <fcntl.h>
#include <dirent.h>
#include <poll.h>
#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <dlfcn.h>
#include <sys/stat.h>
#include <sys/resource.h>

#include <netdb.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netinet/in6.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/udp.h>
#include <netinet/tcp.h>
#include <netinet/ip_icmp.h>
#include <netinet/icmp6.h>

#include <android/log.h>
#include <sys/system_properties.h>

// Tag string for native log output (presumably passed to __android_log_* via log_android — confirm).
#define TAG "NetGuard.JNI"

// Event-loop tuning. NOTE(review): the unit of EPOLL_TIMEOUT is not stated here
// (presumably seconds) — confirm against the epoll_wait caller.
#define EPOLL_TIMEOUT 3600
// Presumably the maximum number of events fetched per epoll_wait call — confirm.
#define EPOLL_EVENTS 20
// NOTE(review): meaning/unit not visible in this header — confirm in the event loop.
#define EPOLL_MIN_CHECK 100

#define TUN_YIELD 10 // packets

// Largest message accepted on the outbound sockets: the maximum IP packet size
// less a minimal IP header (20 bytes IPv4 / 40 bytes IPv6) and an 8-byte ICMP/UDP header.
#define ICMP4_MAXMSG (IP_MAXPACKET - 20 - 8) // bytes (socket)
#define ICMP6_MAXMSG (IPV6_MAXPACKET - 40 - 8) // bytes (socket)
#define UDP4_MAXMSG (IP_MAXPACKET - 20 - 8) // bytes (socket)
#define UDP6_MAXMSG (IPV6_MAXPACKET - 40 - 8) // bytes (socket)

#define ICMP_TIMEOUT 5 // seconds

// UDP_TIMEOUT_53 presumably applies to port 53 (DNS) traffic, UDP_TIMEOUT_ANY to all other ports — confirm.
#define UDP_TIMEOUT_53 15 // seconds
#define UDP_TIMEOUT_ANY 300 // seconds
#define UDP_KEEP_TIMEOUT 60 // seconds
#define UDP_YIELD 10 // packets

#define TCP_INIT_TIMEOUT 20 // seconds ~net.inet.tcp.keepinit
#define TCP_IDLE_TIMEOUT 3600 // seconds ~net.inet.tcp.keepidle
#define TCP_CLOSE_TIMEOUT 20 // seconds
#define TCP_KEEP_TIMEOUT 300 // seconds

// SESSION_MAX evaluates to 409 (integer arithmetic: 1024 * 40 / 100).
// NOTE(review): 1024 presumably reflects a file-descriptor limit — confirm.
#define SESSION_LIMIT 40 // percent
#define SESSION_MAX (1024 * SESSION_LIMIT / 100) // number

#define SEND_BUF_DEFAULT 163840 // bytes


// SOCKS5 handshake progress; presumably the values stored in tcp_session.socks5 — confirm.
#define SOCKS5_NONE 1
#define SOCKS5_HELLO 2
#define SOCKS5_AUTH 3
#define SOCKS5_CONNECT 4
#define SOCKS5_CONNECTED 5

// Shared engine state referenced from struct arguments (args->ctx).
// Member order and types are unchanged; each member is declared on its own line.
// NOTE(review): what `lock` protects is decided by the .c files
// (presumably the session list) — confirm.
struct context {
    pthread_mutex_t lock;
    int pipefds[2]; // presumably a pipe used to wake the event loop — confirm
    int stopping; // presumably non-zero once shutdown was requested — confirm
    int sdk; // presumably the Android SDK level — confirm
    struct ng_session *ng_session; // first node of the session list linked via ng_session.next (presumably)
    char dns_server_v4[INET_ADDRSTRLEN]; // IPv4 address in text form (INET_ADDRSTRLEN-sized)
    char dns_server_v6[INET6_ADDRSTRLEN]; // IPv6 address in text form (INET6_ADDRSTRLEN-sized)
};

// Argument bundle passed (as const struct arguments *) to the packet handlers declared below.
// NOTE(review): env/instance are JNI handles; a JNIEnv is presumably only valid on
// the thread that obtained it — confirm ownership in the .c files.
struct arguments {
    JNIEnv *env;
    jobject instance;
    int tun; // presumably the TUN device file descriptor — confirm
    jboolean fwd53; // presumably whether port-53 (DNS) traffic is forwarded — confirm
    jint rcode; // presumably the DNS response code used for blocked queries — confirm
    struct context *ctx; // shared engine state
};

// Remote endpoint (text address + port); passed as `redirect` to handle_udp/handle_tcp
// and open_udp_socket/open_tcp_socket.
struct allowed {
    char raddr[INET6_ADDRSTRLEN + 1]; // textual IPv4/IPv6 address
    uint16_t rport; // host notation
};

// Queued TCP data segment; singly linked via `next` (tcp_session.forward points at one).
// NOTE(review): ownership of `data` is decided by the .c files — confirm who frees it.
struct segment {
    uint32_t seq; // presumably the sequence number of data[0] — confirm
    uint16_t len; // length of data in bytes
    uint16_t sent; // presumably how many of len bytes were already forwarded — confirm
    int psh; // presumably the TCP PSH flag of the original segment — confirm
    uint8_t *data;
    struct segment *next;
};

// Per-flow ICMP state (one variant of the union in struct ng_session).
struct icmp_session {
    time_t time; // presumably the time of last activity, used for ICMP_TIMEOUT — confirm
    jint uid; // presumably the Android app uid owning the flow — confirm
    int version; // presumably 4 or 6, selecting the ip4/ip6 member below — confirm

    union {
        __be32 ip4; // network notation
        struct in6_addr ip6;
    } saddr;

    union {
        __be32 ip4; // network notation
        struct in6_addr ip6;
    } daddr;

    uint16_t id; // presumably the ICMP echo identifier — confirm byte order in the .c files

    uint8_t stop; // presumably non-zero once the session should be torn down — confirm
};

// UDP session states; presumably the values stored in udp_session.state — confirm.
#define UDP_ACTIVE 0
#define UDP_FINISHING 1
#define UDP_CLOSED 2

// Per-flow UDP state (one variant of the union in struct ng_session).
struct udp_session {
    time_t time; // presumably the time of last activity, used for the UDP_TIMEOUT_* limits — confirm
    jint uid; // presumably the Android app uid owning the flow — confirm
    int version; // presumably 4 or 6, selecting the ip4/ip6 members below — confirm
    uint16_t mss;

    uint64_t sent; // presumably byte counters — confirm
    uint64_t received;

    union {
        __be32 ip4; // network notation
        struct in6_addr ip6;
    } saddr;
    __be16 source; // network notation

    union {
        __be32 ip4; // network notation
        struct in6_addr ip6;
    } daddr;
    __be16 dest; // network notation

    uint8_t state; // presumably one of UDP_ACTIVE / UDP_FINISHING / UDP_CLOSED
};

// Per-flow TCP state (one variant of the union in struct ng_session).
struct tcp_session {
    jint uid; // presumably the Android app uid owning the flow — confirm
    time_t time; // presumably the time of last activity, used for the TCP_*_TIMEOUT limits — confirm
    int version; // presumably 4 or 6, selecting the ip4/ip6 members below — confirm
    uint16_t mss;
    uint8_t recv_scale; // presumably TCP window-scale shift counts — confirm
    uint8_t send_scale;
    uint32_t recv_window; // host notation, scaled
    uint32_t send_window; // host notation, scaled
    uint16_t unconfirmed; // packets

    uint32_t remote_seq; // confirmed bytes received, host notation
    uint32_t local_seq; // confirmed bytes sent, host notation
    uint32_t remote_start; // presumably the initial sequence numbers of each side — confirm
    uint32_t local_start;

    uint32_t acked; // host notation
    long long last_keep_alive; // presumably a get_ms() timestamp (same type) — confirm

    uint64_t sent; // presumably byte counters — confirm
    uint64_t received;

    union {
        __be32 ip4; // network notation
        struct in6_addr ip6;
    } saddr;
    __be16 source; // network notation

    union {
        __be32 ip4; // network notation
        struct in6_addr ip6;
    } daddr;
    __be16 dest; // network notation

    uint8_t state; // presumably a TCP_* state constant from <netinet/tcp.h> — confirm
    uint8_t socks5; // presumably one of the SOCKS5_* values — confirm
    struct segment *forward; // list of queued data segments (linked via segment.next)
};

// One tracked flow; nodes form a singly linked list via `next`.
struct ng_session {
    uint8_t protocol; // presumably an IPPROTO_* value selecting the valid union member — confirm
    union {
        struct icmp_session icmp;
        struct udp_session udp;
        struct tcp_session tcp;
    };
    jint socket; // presumably the outbound socket file descriptor — confirm
    struct epoll_event ev; // epoll registration data for `socket` (presumably)
    struct ng_session *next;
};

// IPv6

// IPv6 pseudo-header for upper-layer checksums (RFC 8200 section 8.1):
// source, destination, upper-layer length, 3 zero bytes, next-header value.
// Packed so it can be checksummed as raw bytes.
struct ip6_hdr_pseudo {
    struct in6_addr ip6ph_src;
    struct in6_addr ip6ph_dst;
    u_int32_t ip6ph_len;
    u_int8_t ip6ph_zero[3];
    u_int8_t ip6ph_nxt;
} __packed;

// pcap link-layer type for raw IP packets (libpcap LINKTYPE_RAW = 101).
#define LINKTYPE_RAW 101

// TLS

// Maximum TLS SNI host-name length handled (presumably excluding the terminating NUL — confirm).
#define TLS_SNI_LENGTH 255


// Fixed-size part of a DNS resource record that starts with a 2-byte
// compressed name pointer; all fields are big-endian (__be*). Packed to match the wire layout.
typedef struct dns_rr {
    __be16 qname_ptr;
    __be16 qtype;
    __be16 qclass;
    __be32 ttl;
    __be16 rdlength;
} __packed dns_rr;

// DHCP

// DHCP "magic cookie" that precedes the options field (RFC 2131), in host notation here.
#define DHCP_OPTION_MAGIC_NUMBER (0x63825363)

// Fixed BOOTP/DHCP header as laid out on the wire (packed); option_format presumably
// holds the magic cookie in network byte order — confirm against the parser.
typedef struct dhcp_packet {
    uint8_t opcode;
    uint8_t htype;
    uint8_t hlen;
    uint8_t hops;
    uint32_t xid;
    uint16_t secs;
    uint16_t flags;
    uint32_t ciaddr;
    uint32_t yiaddr;
    uint32_t siaddr;
    uint32_t giaddr;
    uint8_t chaddr[16];
    uint8_t sname[64];
    uint8_t file[128];
    uint32_t option_format;
} __packed dhcp_packet;

// Header of one DHCP option: code byte followed by the length of the option data.
typedef struct dhcp_option {
    uint8_t code;
    uint8_t length;
} __packed dhcp_option;

// Signature matches a pthread_create start routine.
// NOTE(review): `a` is presumably a struct arguments * — confirm in the definition.
void *handle_events(void *a);


// NOTE(review): presumably releases every session held in ctx — confirm.
void clear(struct context *ctx);

// Per-protocol session housekeeping. `sessions`/`maxsessions` presumably are the
// current session count and its limit; return-value meaning is defined in the .c files — confirm.
int check_icmp_session(const struct arguments *args,
                       struct ng_session *s,
                       int sessions, int maxsessions);

int check_udp_session(const struct arguments *args,
                      struct ng_session *s,
                      int sessions, int maxsessions);

int check_tcp_session(const struct arguments *args,
                      struct ng_session *s,
                      int sessions, int maxsessions);

// NOTE(review): presumably (re)registers the session's socket with epoll_fd — confirm.
int monitor_tcp_session(const struct arguments *args, struct ng_session *s, int epoll_fd);

// Session timeouts (presumably seconds, cf. the *_TIMEOUT macros), possibly
// shortened as `sessions` approaches `maxsessions` — confirm.
int get_icmp_timeout(const struct icmp_session *u, int sessions, int maxsessions);

int get_udp_timeout(const struct udp_session *u, int sessions, int maxsessions);

int get_tcp_timeout(const struct tcp_session *t, int sessions, int maxsessions);

// Returns the MTU in use (presumably bytes — confirm in the definition).
// Declared with (void): an empty `()` is a non-prototype declaration in C
// before C23 and disables argument checking at call sites.
uint16_t get_mtu(void);

// Returns the default TCP MSS for IP `version` (presumably 4 or 6 — confirm).
uint16_t get_default_mss(int version);

// Event handlers for descriptors reported ready by epoll; `ev` is the ready event.
// NOTE(review): return-value meaning of check_tun is defined in the .c file — confirm.
int check_tun(const struct arguments *args,
              const struct epoll_event *ev,
              const int epoll_fd,
              int sessions, int maxsessions);

void check_icmp_socket(const struct arguments *args, const struct epoll_event *ev);

void check_udp_socket(const struct arguments *args, const struct epoll_event *ev);

// TCP flow-control helpers; results are presumably in bytes — confirm.
uint32_t get_send_window(const struct tcp_session *cur);

uint32_t get_receive_buffer(const struct ng_session *cur);

uint32_t get_receive_window(const struct ng_session *cur);

void check_tcp_socket(const struct arguments *args,
                      const struct epoll_event *ev,
                      const int epoll_fd);

// Protocol classification helpers (presumably for IPv6 next-header walking — confirm).
int is_lower_layer(int protocol);

int is_upper_layer(int protocol);

// Processes one packet of `length` bytes in `buffer` (presumably read from the TUN device).
void handle_ip(const struct arguments *args,
               const uint8_t *buffer, size_t length,
               const int epoll_fd,
               int sessions, int maxsessions);

// Per-protocol handlers: `pkt` is the IP packet and `payload` its transport part.
// NOTE(review): meaning of the jboolean result is defined in the .c files — confirm.
jboolean handle_icmp(const struct arguments *args,
                     const uint8_t *pkt, size_t length,
                     const uint8_t *payload,
                     int uid,
                     const int epoll_fd);

int has_udp_session(const struct arguments *args, const uint8_t *pkt, const uint8_t *payload);

jboolean handle_udp(const struct arguments *args,
                    const uint8_t *pkt, size_t length,
                    const uint8_t *payload,
                    int uid, struct allowed *redirect,
                    const int epoll_fd);

// NOTE(review): presumably frees the session's queued `forward` segments — confirm.
void clear_tcp_data(struct tcp_session *cur);

jboolean handle_tcp(const struct arguments *args,
                    const uint8_t *pkt, size_t length,
                    const uint8_t *payload,
                    int uid, int allowed, struct allowed *redirect,
                    const int epoll_fd);

// Queues `datalen` bytes of `data` on `cur` (presumably onto cur->forward);
// `session` is presumably a description string for logging — confirm.
void queue_tcp(const struct arguments *args,
               const struct tcphdr *tcphdr,
               const char *session, struct tcp_session *cur,
               const uint8_t *data, uint16_t datalen);

// Open the outbound socket for a session; `redirect` optionally overrides the destination.
// NOTE(review): presumably return a file descriptor, negative on failure — confirm.
int open_icmp_socket(const struct arguments *args, const struct icmp_session *cur);

int open_udp_socket(const struct arguments *args,
                    const struct udp_session *cur, const struct allowed *redirect);

int open_tcp_socket(const struct arguments *args,
                    const struct tcp_session *cur, const struct allowed *redirect);

// Emit TCP control/data segments for `cur` (presumably back towards the TUN device).
// NOTE(review): return-value meaning is defined in the .c files — confirm.
int write_syn_ack(const struct arguments *args, struct tcp_session *cur);

int write_ack(const struct arguments *args, struct tcp_session *cur);

int write_data(const struct arguments *args, struct tcp_session *cur,
               const uint8_t *buffer, size_t length);

int write_fin_ack(const struct arguments *args, struct tcp_session *cur);

void write_rst(const struct arguments *args, struct tcp_session *cur);


// Build and write one packet for the given session; syn/ack/fin/rst select TCP flags.
// NOTE(review): presumably return bytes written or -1 on error — confirm.
ssize_t write_icmp(const struct arguments *args, const struct icmp_session *cur,
                   uint8_t *data, size_t datalen);

ssize_t write_udp(const struct arguments *args, const struct udp_session *cur,
                  uint8_t *data, size_t datalen);

ssize_t write_tcp(const struct arguments *args, const struct tcp_session *cur,
                  const uint8_t *data, size_t datalen,
                  int syn, int ack, int fin, int rst);


// Checksum over `length` bytes of `buffer`, continuing from the partial value `start`
// (presumably the Internet one's-complement checksum — confirm in the definition).
uint16_t calc_checksum(uint16_t start, const uint8_t *buffer, size_t length);


// printf-style logging at Android log priority `prio`.
// NOTE(review): consider __attribute__((format(printf, 2, 3))) once all call
// sites are known to pass literal, type-correct format strings.
void log_android(int prio, const char *fmt, ...);

// Compares two 32-bit values (presumably TCP sequence numbers with wrap-around);
// result sign convention is defined in the implementation — confirm.
int compare_u32(uint32_t seq1, uint32_t seq2);


// Hex representation of `len` bytes of `data`.
// NOTE(review): presumably heap-allocated and owned by the caller — confirm who frees it.
char *hex(const u_int8_t *data, const size_t len);

// NOTE(review): presumably non-zero when `fd` has data ready to read — confirm.
int is_readable(int fd);


// Current time as `long long` (presumably milliseconds, per the name — confirm).
// Declared with (void): an empty `()` is a non-prototype declaration in C before C23.
long long get_ms(void);

// Allocation-tracking wrappers around the C allocator.
// NOTE(review): presumably `tag` labels an allocation for leak diagnostics and
// `file`/`line` identify the call site of a free — confirm in the implementation.
// Parameter names avoid the reserved `__` prefix (C11 7.1.3); names in a
// prototype do not affect callers or the matching definitions.
void ng_add_alloc(void *ptr, const char *tag);

void ng_delete_alloc(void *ptr, const char *file, int line);

void *ng_malloc(size_t byte_count, const char *tag);

void *ng_calloc(size_t item_count, size_t item_size, const char *tag);

void ng_free(void *ptr, const char *file, int line);
// Packet inspection hooks over `length` bytes of `data`; `direction` is a caller-supplied label.
// NOTE(review): meaning of the jboolean results (e.g. allow vs. drop) is defined in the .c files — confirm.
void log_packet_hex(const struct arguments *args, const uint8_t *data, size_t length, const char *direction);
jboolean filter_tcp_packet(const struct arguments *args, const uint8_t *data, size_t length, const char *direction);
jboolean filter_udp_packet(const struct arguments *args, const uint8_t *data, size_t length, const char *direction);
jboolean filter_icmp_packet(const struct arguments *args, const uint8_t *data, size_t length, const char *direction);