// SPDX-License-Identifier: GPL-2.0-only
/*
 * Connection tracking protocol helper module for ESP.
 *
 */

#include <linux/module.h>
#include <linux/types.h>
#include <linux/timer.h>
#include <linux/list.h>
#include <linux/seq_file.h>
#include <linux/in.h>
#include <linux/netdevice.h>
#include <linux/skbuff.h>
#include <linux/slab.h>
#include <net/dst.h>
#include <net/net_namespace.h>
#include <net/netns/generic.h>
#include <net/netfilter/nf_conntrack_l4proto.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_timeout.h>
#include <linux/netfilter/nf_conntrack_proto_esp.h>
#ifdef CONFIG_NF_CONNTRACK_OFFLOAD
#include <net/netfilter/nf_conntrack_offload.h>
#endif

/*
 * Default ESP timeouts (in jiffies). nf_conntrack_esp_init_net() copies
 * them into each network namespace.
 */
static const unsigned int esp_timeouts[ESP_CT_MAX] = {
	[ESP_CT_UNREPLIED]	= 30 * HZ,	/* no reply seen yet */
	[ESP_CT_REPLIED]	= 180 * HZ,	/* traffic seen both ways */
};

/* Value stored in _esp_table.inuse for an allocated slot (free slots hold 0). */
#define IPSEC_INUSE    1
/* Number of slots in esp_table. */
#define MAX_PORTS      16 /* KT: Changed to match MAX_VPN_CONNECTION */
/* First temporary SPI (tspi) handed out by alloc_esp_entry(). */
#define TEMP_SPI_START 1500

/*
 * One tracked ESP flow. The "l" (local) side is the source of the packet
 * that created the entry; the "r" (remote) side is its destination.
 */
struct _esp_table {
	u_int32_t l_spi;	/* SPI seen in packets from l_ip */
	u_int32_t r_spi;	/* SPI seen in packets from r_ip; 0 = not learned yet */
	union nf_inet_addr l_ip;	/* address of the local side */
	union nf_inet_addr r_ip;	/* address of the remote side */
	u_int32_t timeout;	/* NOTE(review): not referenced in this file */
	u_int16_t tspi;		/* temporary SPI written into the conntrack tuple */
	struct nf_conn *ct;	/* conntrack last associated with this entry */
	int    pkt_rcvd;	/* number of packets mapped to this entry */
	int    inuse;		/* IPSEC_INUSE when allocated, 0 when free */
};

/*
 * Global (not per-netns) table of tracked ESP flows.
 * NOTE(review): no locking around this table is visible in this file.
 */
static struct _esp_table esp_table[MAX_PORTS];

/* Next temporary SPI to assign; advanced by alloc_esp_entry(). */
static u_int16_t next_tspi = TEMP_SPI_START;

/* Return the ESP conntrack state embedded in @net's conntrack data. */
static inline struct nf_esp_net *esp_pernet(struct net *net)
{
	struct nf_esp_net *esp_net = &net->ct.nf_ct_proto.esp;

	return esp_net;
}

/*
 * Allocate a free ESP table entry.
 */
struct _esp_table *alloc_esp_entry(void)
{
	int idx = 0;

	for (; idx < MAX_PORTS; idx++) {
		if (esp_table[idx].inuse == IPSEC_INUSE)
			continue;

		memset(&esp_table[idx], 0, sizeof(struct _esp_table));
		esp_table[idx].tspi  = next_tspi++;
		esp_table[idx].inuse = IPSEC_INUSE;

		pr_debug("[%d] alloc_entry() tspi(%u)\n",
			 idx, esp_table[idx].tspi);

		return &esp_table[idx];
	}
	return NULL;
}

/*
 * Search an ESP table entry by ct.
 */
static struct _esp_table *search_esp_entry_by_ct(struct nf_conn *ct)
{
	int idx = 0;

	for (; idx < MAX_PORTS; idx++) {
		if (esp_table[idx].inuse != IPSEC_INUSE)
			continue;

		pr_debug("Searching entry->ct(%p) <--> ct(%p)\n",
			 esp_table[idx].ct, ct);

		/* checking ct */
		if (esp_table[idx].ct == ct) {
			pr_debug("Found entry %d with ct(%p)\n", idx, ct);

			return &esp_table[idx];
		}
	}

	pr_debug("No Entry for ct(%p)\n", ct);
	return NULL;
}

/*
 * Search an ESP table entry by source IP.
 * If found one, update the spi value
 */
static struct _esp_table
*search_esp_entry_by_ip(struct nf_conntrack_tuple *tuple,
			const __u32 spi)
{
	int idx = 0;
	union nf_inet_addr *src_ip = &tuple->src.u3;
	union nf_inet_addr *dst_ip = &tuple->dst.u3;
	struct _esp_table *esp_entry = esp_table;

	pr_debug("  Searching for SPI %x by IP %pI4\n", spi, &tuple->src.u3.ip);
	for (; idx < MAX_PORTS; idx++, esp_entry++) {

		/* make sure l_ip is LAN IP */
		if (nf_inet_addr_cmp(src_ip, &esp_entry->l_ip)) {
			pr_debug("  [%d] found SPI 0x%x entry with l_ip, setting r_spi to 0\n", idx, spi);

			/* This is a new connection of the same LAN host */
			if ((nf_inet_addr_cmp(dst_ip, &esp_entry->r_ip)) ||
			    (esp_entry->l_spi != spi)) {
				esp_entry->r_ip = *dst_ip;
				esp_entry->r_spi = 0;
			}
			esp_entry->l_spi = spi;
			return esp_entry;
		} else if (nf_inet_addr_cmp(src_ip, &esp_entry->r_ip)) {
			pr_debug("  [%d] found entry with r_ip\n", idx);
			/* FIXME */
			if (esp_entry->r_spi == 0) {
				pr_debug(
				   "  found entry with "
				   "r_ip and r_spi == 0\n");
				esp_entry->r_spi = spi;
				return esp_entry;
			}
			/* We cannot handle spi changed at WAN side */
			pr_debug("  found entry with r_ip but r_spi != 0\n");
		}
	}
	pr_debug("No Entry for spi(0x%x)\n", spi);
	return NULL;
}

/*
 * Search an ESP table entry by spi
 */
static struct _esp_table
*search_esp_entry_by_spi(const __u32 spi, const union nf_inet_addr *src_ip)
{
	int idx = 0;
	struct _esp_table *esp_entry = esp_table;

	for (; idx < MAX_PORTS; idx++, esp_entry++) {
		pr_debug("  [%d]  Searching for spi 0x%x in 0x%x <=> 0x%x\n",
			 idx, spi, esp_entry->l_spi, esp_entry->r_spi);

		if ((spi == esp_entry->l_spi) || (spi == esp_entry->r_spi)) {
			pr_debug("  [%d] found matching entry tspi %u, lspi 0x%x rspi 0x%x ct 0x%p\n",
				 idx, esp_entry->tspi, esp_entry->l_spi, esp_entry->r_spi, esp_entry->ct);

			/* l_spi and r_spi may be the same */
			if ((spi == esp_entry->l_spi) &&
			    (nf_inet_addr_cmp(&esp_entry->r_ip, src_ip))) {
				pr_debug("  Set l_spi(0x%x) = r_spi\n", spi);
				esp_entry->r_spi = spi;
			}

			return esp_entry;
		}
	}
	pr_debug("No Entry for spi(0x%x)\n", spi);

	return NULL;
}

/* PUBLIC CONNTRACK PROTO HELPER FUNCTIONS */

/*
 * Called when the connection is deleted.
 *
 * Releases the ESP table entry that refers to @ct by zeroing it, which
 * also clears its in-use flag. Logs a debug message if none exists.
 */
void esp_destroy(struct nf_conn *ct)
{
	struct _esp_table *entry;

	pr_debug("DEL ESP entry ct(%p)\n", ct);

	entry = search_esp_entry_by_ct(ct);
	if (!entry) {
		pr_debug("DEL ESP Failed for ct(%p): no such entry\n", ct);
		return;
	}

	memset(entry, 0, sizeof(*entry));
}

/* esp hdr info to tuple */
bool esp_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff,
		      struct net *net, struct nf_conntrack_tuple *tuple)
{
	struct esphdr _esphdr, *esphdr;
	struct _esp_table *esp_entry = NULL;

	esphdr = skb_header_pointer(skb, dataoff, sizeof(_esphdr), &_esphdr);
	if (!esphdr) {
		/* try to behave like "nf_conntrack_proto_generic" */
		tuple->src.u.all = 0;
		tuple->dst.u.all = 0;
		return true;
	}

	pr_debug("Enter pkt_to_tuple() with spi 0x%x\n", esphdr->spi);
	/* check if esphdr has a new SPI:
	 *   if no, update tuple with correct tspi and increment pkt count;
	 *   if yes, check if we have seen the source IP:
	 *             if yes, do the tspi and pkt count update
	 *             if no, create a new entry
	 */

	esp_entry = search_esp_entry_by_spi(esphdr->spi, &tuple->src.u3);
	if (esp_entry == NULL) {
		esp_entry = search_esp_entry_by_ip(tuple, esphdr->spi);
		if (esp_entry == NULL) {
#if 0
			/* Because SA is simplex, it's possible that WAN
			 * starts connection first.
			 * We need to make sure that the connection
			 * starts from LAN.
			 */
			if (((unsigned char *)&(tuple->src.u3.ip))[0] != 192) {
				pr_debug("src_ip %pI4 is WAN IP, DROP packet\n",
					&tuple->src.u3.ip);
				return false;
			}
#endif
			esp_entry = alloc_esp_entry();
			if (esp_entry == NULL) {
				pr_debug("Too many entries. New spi(0x%x)\n",
					 esphdr->spi);
				return false;
			}

			esp_entry->l_spi = esphdr->spi;
			esp_entry->l_ip = tuple->src.u3;
			esp_entry->r_ip = tuple->dst.u3;

			pr_debug("[index %ld] alloc_entry() tspi(%u) l_spi(0x%x) "
			       "r_spi(0x%x) ct(%p)\n",
				(long int) (esp_entry - &esp_table[0]),
				esp_entry->tspi, esp_entry->l_spi, esp_entry->r_spi,
				esp_entry->ct);
		}

	}

	tuple->dst.u.esp.spi = tuple->src.u.esp.spi = esp_entry->tspi;
	esp_entry->pkt_rcvd++;

	return true;
}

#ifdef CONFIG_NF_CONNTRACK_PROCFS
/*
 * print private data for conntrack: both per-connection ESP timeouts,
 * converted from jiffies to seconds.
 */
static void esp_print_conntrack(struct seq_file *s, struct nf_conn *ct)
{
	unsigned int unreplied_secs = ct->proto.esp.timeout / HZ;
	unsigned int replied_secs = ct->proto.esp.stream_timeout / HZ;

	seq_printf(s, "timeout=%u, stream_timeout=%u ",
		   unreplied_secs, replied_secs);
}
#endif

static unsigned int *esp_get_timeouts(struct net *net)
{
	return esp_pernet(net)->timeouts;
}

/*
 * Returns verdict for packet, and may modify conntrack.
 *
 * - For a not-yet-confirmed @ct, seeds ct->proto.esp.timeout (unreplied)
 *   and ct->proto.esp.stream_timeout (replied). The values come from the
 *   ct's timeout policy when nf_ct_timeout_lookup() returns one, and from
 *   the per-netns defaults otherwise.
 * - Looks up the ESP table entry for the packet's SPI and records @ct in
 *   it; esp_destroy() later finds the entry through this pointer.
 * - Refreshes the ct timer. Once IPS_SEEN_REPLY is set it uses
 *   stream_timeout and marks the ct ASSURED the first time; before that it
 *   uses the unreplied timeout.
 *
 * Always returns NF_ACCEPT, including when the ESP header cannot be read.
 */
int nf_conntrack_esp_packet(struct nf_conn *ct,
			    struct sk_buff *skb,
			    unsigned int dataoff,
			    enum ip_conntrack_info ctinfo,
			    const struct nf_hook_state *state)
{
	struct esphdr _esphdr, *esphdr;
	struct iphdr *iph4;
	struct ipv6hdr *iph6;
	struct _esp_table *esp_entry;
	union nf_inet_addr addr;

#ifdef CONFIG_NF_CONNTRACK_OFFLOAD
	/* NOTE(review): may be NULL if no offload extension is attached;
	 * it is used (presumably via the ct_offload_orig/ct_offload_repl
	 * macros) below without a check — confirm. */
	struct nf_conn_offload *ct_offload = nf_conn_offload_find(ct);
#endif

	if (!nf_ct_is_confirmed(ct)) {
		unsigned int *timeouts = nf_ct_timeout_lookup(ct);

		/* No timeout policy attached: fall back to netns defaults. */
		if (!timeouts)
			timeouts = esp_get_timeouts(nf_ct_net(ct));

		/* initialize to sane value.  Ideally a conntrack helper
		 * (e.g. in case of pptp) is increasing them */
		ct->proto.esp.stream_timeout = timeouts[ESP_CT_REPLIED];
		ct->proto.esp.timeout = timeouts[ESP_CT_UNREPLIED];
	}

	esphdr = skb_header_pointer(skb, dataoff, sizeof(_esphdr), &_esphdr);

	/* Truncated ESP header: accept without touching the ESP table. */
	if (!esphdr)
		return NF_ACCEPT;

	/* An all-zero address is passed here, so search_esp_entry_by_spi()
	 * only updates r_spi if an entry's r_ip is itself all-zero. */
	memset(&addr, '\0', sizeof(addr));
	esp_entry = search_esp_entry_by_spi(esphdr->spi, &addr);
	if (esp_entry != NULL)
		esp_entry->ct = ct;
	else
		pr_debug("cannot find an entry with SPI %x\n", esphdr->spi);

	/* Debug trace only; iph4/iph6 are not used otherwise. Only IPv4 and
	 * IPv6 are expected here; any other family triggers BUG(). */
	switch (nf_ct_l3num(ct)) {
	case AF_INET:
		iph4 = ip_hdr(skb);
		pr_debug("(0x%x) %pI4 <-> %pI4 status %s info %d %s\n",
		 esphdr->spi, &iph4->saddr, &iph4->daddr,
		 (ct->status & IPS_SEEN_REPLY) ? "SEEN" : "NOT_SEEN",
		 ctinfo, (ctinfo == IP_CT_NEW) ? "CT_NEW" : "SEEN_REPLY");
		break;
	case AF_INET6:
		iph6 = ipv6_hdr(skb);
		pr_debug("(0x%x) %pI6c <-> %pI6c status %s info %d %s\n",
		 esphdr->spi, &iph6->saddr, &iph6->daddr,
		 (ct->status & IPS_SEEN_REPLY) ? "SEEN" : "NOT_SEEN",
		 ctinfo, (ctinfo == IP_CT_NEW) ? "CT_NEW" : "SEEN_REPLY");
		break;
	default:
		BUG();
	}

#ifdef CONFIG_NF_CONNTRACK_OFFLOAD
	/* Record the on-the-wire SPI for the packet's direction.
	 * NOTE(review): ct_offload_orig/ct_offload_repl are not declared in
	 * this function; presumably macros from nf_conntrack_offload.h that
	 * expand to fields of ct_offload — confirm. */
	if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) {
		ct_offload_orig.spi = esphdr->spi;
	} else {
		ct_offload_repl.spi = esphdr->spi;
	}
#endif
	/* If we've seen traffic both ways, this is a ESP connection.
	 * Extend timeout. */
	if (ct->status & IPS_SEEN_REPLY) {
		nf_ct_refresh_acct(ct, ctinfo, skb,
				   ct->proto.esp.stream_timeout);
		/* Also, more likely to be important, and not a probe. */
		if (!test_and_set_bit(IPS_ASSURED_BIT, &ct->status))
			nf_conntrack_event_cache(IPCT_ASSURED, ct);
	} else
		nf_ct_refresh_acct(ct, ctinfo, skb,
				   ct->proto.esp.timeout);

	return NF_ACCEPT;
}

#ifdef CONFIG_NF_CONNTRACK_TIMEOUT

#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nfnetlink_cttimeout.h>

static int esp_timeout_nlattr_to_obj(struct nlattr *tb[],
				     struct net *net, void *data)
{
	unsigned int *timeouts = data;
	struct nf_esp_net *net_esp = esp_pernet(net);

	if (!timeouts)
		timeouts = esp_get_timeouts(net);
	/* set default timeouts for ESP. */
	timeouts[ESP_CT_UNREPLIED] = net_esp->timeouts[ESP_CT_UNREPLIED];
	timeouts[ESP_CT_REPLIED] = net_esp->timeouts[ESP_CT_REPLIED];

	if (tb[CTA_TIMEOUT_ESP_UNREPLIED]) {
		timeouts[ESP_CT_UNREPLIED] =
			ntohl(nla_get_be32(tb[CTA_TIMEOUT_ESP_UNREPLIED])) * HZ;
	}
	if (tb[CTA_TIMEOUT_ESP_REPLIED]) {
		timeouts[ESP_CT_REPLIED] =
			ntohl(nla_get_be32(tb[CTA_TIMEOUT_ESP_REPLIED])) * HZ;
	}
	return 0;
}

/*
 * Dump an ESP timeout object (indexed by ESP_CT_*) as cttimeout netlink
 * attributes, converting jiffies to seconds. The replied attribute is only
 * attempted if the unreplied one fit. Returns 0, or -ENOSPC when @skb has
 * no room.
 */
static int
esp_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data)
{
	const unsigned int *timeouts = data;

	if (nla_put_be32(skb, CTA_TIMEOUT_ESP_UNREPLIED,
			 htonl(timeouts[ESP_CT_UNREPLIED] / HZ)))
		return -ENOSPC;

	if (nla_put_be32(skb, CTA_TIMEOUT_ESP_REPLIED,
			 htonl(timeouts[ESP_CT_REPLIED] / HZ)))
		return -ENOSPC;

	return 0;
}

/* cttimeout attribute policy: both ESP timeouts are 32-bit values (seconds). */
static const struct nla_policy
esp_timeout_nla_policy[CTA_TIMEOUT_ESP_MAX + 1] = {
	[CTA_TIMEOUT_ESP_REPLIED]	= { .type = NLA_U32 },
	[CTA_TIMEOUT_ESP_UNREPLIED]	= { .type = NLA_U32 },
};
#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */

void nf_conntrack_esp_init_net(struct net *net)
{
	struct nf_esp_net *net_esp = esp_pernet(net);
	int i;

	for (i = 0; i < ESP_CT_MAX; i++)
		net_esp->timeouts[i] = esp_timeouts[i];
}

/*
 * protocol helper struct
 *
 * ESP (IPPROTO_ESP) tracker registration. Tuple <-> netlink conversion
 * reuses the generic port helpers (nf_ct_port_*), presumably because the
 * 16-bit temporary SPI is stored in the tuple's port-sized field (confirm
 * against the tuple layout). Procfs printing and cttimeout support are
 * only wired up when their Kconfig options are enabled.
 */
const struct nf_conntrack_l4proto nf_conntrack_l4proto_esp = {
	.l4proto	 = IPPROTO_ESP,
#ifdef CONFIG_NF_CONNTRACK_PROCFS
	.print_conntrack = esp_print_conntrack,
#endif
#if IS_ENABLED(CONFIG_NF_CT_NETLINK)
	.tuple_to_nlattr = nf_ct_port_tuple_to_nlattr,
	.nlattr_tuple_size = nf_ct_port_nlattr_tuple_size,
	.nlattr_to_tuple = nf_ct_port_nlattr_to_tuple,
	.nla_policy	 = nf_ct_port_nla_policy,
#endif
#ifdef CONFIG_NF_CONNTRACK_TIMEOUT
	.ctnl_timeout    = {
		.nlattr_to_obj	= esp_timeout_nlattr_to_obj,
		.obj_to_nlattr	= esp_timeout_obj_to_nlattr,
		.nlattr_max	= CTA_TIMEOUT_ESP_MAX,
		/* one unsigned int per ESP_CT_* state */
		.obj_size	= sizeof(unsigned int) * ESP_CT_MAX,
		.nla_policy	= esp_timeout_nla_policy,
	},
#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
};
