// SPDX-License-Identifier: GPL-2.0-only
/*
 * Connection tracking protocol helper module for AH.
 *
 */

#ifdef CONFIG_BCM_KF_CM

#include <linux/module.h>
#include <linux/types.h>
#include <linux/timer.h>
#include <linux/list.h>
#include <linux/seq_file.h>
#include <linux/in.h>
#include <linux/netdevice.h>
#include <linux/skbuff.h>
#include <linux/slab.h>
#include <net/dst.h>
#include <net/net_namespace.h>
#include <net/netns/generic.h>
#include <net/ah.h>
#include <net/netfilter/nf_conntrack_l4proto.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_timeout.h>
#include <linux/netfilter/nf_conntrack_proto_ah.h>
#ifdef CONFIG_NF_CONNTRACK_OFFLOAD
#include <net/netfilter/nf_conntrack_offload.h>
#endif

/*
 * Compiled-in default timeouts, in jiffies, indexed by AH conntrack state.
 * nf_conntrack_ah_init_net() copies them into each namespace.
 */
static const unsigned int ah_timeouts[AH_CT_MAX] = {
	/* once traffic has been seen in both directions */
	[AH_CT_REPLIED]		= 180 * HZ,
	/* until a reply has been seen */
	[AH_CT_UNREPLIED]	= 30 * HZ,
};

/* Value of _ah_table.inuse for an allocated slot (free slots are all-zero). */
#define IPSEC_INUSE    1
/* Number of slots in ah_table. */
#define MAX_PORTS      32 /* KT: Changed to match MAX_VPN_CONNECTION */
/* First temporary SPI (tspi) handed out by alloc_ah_entry(). */
#define TEMP_SPI_START 1500

/*
 * One tracked AH peer pair.  AH carries no ports, so each entry maps the
 * SPIs seen from both addresses onto one locally chosen temporary SPI
 * (tspi), which ah_pkt_to_tuple() writes into the conntrack tuple.
 */
struct _ah_table {
	/* SPI seen in packets sent from l_ip */
	u32 l_spi;
	/* SPI learned from packets sent from r_ip; 0 until learned */
	u32 r_spi;
	/* source address of the packet that created the entry */
	union nf_inet_addr l_ip;
	/* destination address of that packet (may be updated later) */
	union nf_inet_addr r_ip;
	/* not referenced by the code in this file */
	u32 timeout;
	/* temporary SPI placed in conntrack tuples for this entry */
	u16 tspi;
	/* conntrack bound by ah_new(); looked up by ah_destroy() */
	struct nf_conn *ct;
	/* number of packets mapped through this entry */
	int    pkt_rcvd;
	/* IPSEC_INUSE while allocated; zeroed when the entry is freed */
	int    inuse;
};

/*
 * Single file-scope table shared by all network namespaces.
 * NOTE(review): no locking is visible in this file; confirm callers
 * serialize access from concurrent packet processing.
 */
static struct _ah_table ah_table[MAX_PORTS];

/* Next temporary SPI to hand out; a u16, so it wraps to 0 after 65535. */
static u16 next_tspi = TEMP_SPI_START;

/* Return the AH conntrack state stored in @net's conntrack data. */
static inline struct nf_ah_net *ah_pernet(struct net *net)
{
	struct nf_ah_net *pernet = &net->ct.nf_ct_proto.ah;

	return pernet;
}

/* Allocate a free AH table entry */
struct _ah_table *alloc_ah_entry(void)
{
	int idx = 0;

	for (; idx < MAX_PORTS; idx++) {
		if (ah_table[idx].inuse == IPSEC_INUSE)
			continue;

		memset(&ah_table[idx], 0, sizeof(struct _ah_table));
		ah_table[idx].tspi  = next_tspi++;
		ah_table[idx].inuse = IPSEC_INUSE;

		pr_debug("[%d] alloc_entry() tspi(%u)\n",
			 idx, ah_table[idx].tspi);

		return &ah_table[idx];
	}
	return NULL;
}

/* Search an AH table entry by ct */
static struct _ah_table *search_ah_entry_by_ct(struct nf_conn *ct)
{
	int idx = 0;

	for (; idx < MAX_PORTS; idx++) {
		if (ah_table[idx].inuse != IPSEC_INUSE)
			continue;

		pr_debug("Searching entry->ct(%p) <--> ct(%p)\n",
			 ah_table[idx].ct, ct);

		/* checking ct */
		if (ah_table[idx].ct == ct) {
			pr_debug("Found entry %d with ct(%p)\n", idx, ct);

			return &ah_table[idx];
		}
	}

	pr_debug("No Entry for ct(%p)\n", ct);
	return NULL;
}

/* Search an AH table entry by source IP.
 * If found one, update the spi value
 */
static struct _ah_table
*search_ah_entry_by_ip(struct nf_conntrack_tuple *tuple,
			const __u32 spi)
{
	int idx = 0;
	union nf_inet_addr *src_ip = &tuple->src.u3;
	union nf_inet_addr *dst_ip = &tuple->dst.u3;
	struct _ah_table *ah_entry = ah_table;

	pr_debug("  Searching for SPI %x by IP %pI4\n", spi, &tuple->src.u3.ip);
	for (; idx < MAX_PORTS; idx++, ah_entry++) {
		/* make sure l_ip is LAN IP */
		if (nf_inet_addr_cmp(src_ip, &ah_entry->l_ip)) {
			pr_debug("  [%d] found SPI 0x%x entry with l_ip, setting r_spi to 0\n",
				 idx, spi);

			/* This is a new connection of the same LAN host */
			if ((nf_inet_addr_cmp(dst_ip, &ah_entry->r_ip)) ||
			    ah_entry->l_spi != spi) {
				ah_entry->r_ip = *dst_ip;
				ah_entry->r_spi = 0;
			}
			ah_entry->l_spi = spi;
			return ah_entry;
		} else if (nf_inet_addr_cmp(src_ip, &ah_entry->r_ip)) {
			pr_debug("  [%d] found entry with r_ip\n", idx);
			/* FIXME */
			if (ah_entry->r_spi == 0) {
				pr_debug("  found entry with r_ip and r_spi == 0\n");
				ah_entry->r_spi = spi;
				return ah_entry;
			}
			/* We cannot handle spi changed at WAN side */
			pr_debug("  found entry with r_ip but r_spi != 0\n");
		}
	}
	pr_debug("No Entry for spi(0x%x)\n", spi);
	return NULL;
}

/* Search an AH table entry by spi */
static struct _ah_table
*search_ah_entry_by_spi(const __u32 spi, const union nf_inet_addr *src_ip)
{
	int idx = 0;
	struct _ah_table *ah_entry = ah_table;

	for (; idx < MAX_PORTS; idx++, ah_entry++) {
		pr_debug("  [%d]  Searching for spi 0x%x in 0x%x <=> 0x%x\n",
			 idx, spi, ah_entry->l_spi, ah_entry->r_spi);

		if (spi == ah_entry->l_spi || spi == ah_entry->r_spi) {
			pr_debug("  [%d] found matching entry tspi %u, lspi 0x%x rspi 0x%x ct 0x%p\n",
				 idx, ah_entry->tspi, ah_entry->l_spi, ah_entry->r_spi, ah_entry->ct);

			/* l_spi and r_spi may be the same */
			if (spi == ah_entry->l_spi &&
			    (nf_inet_addr_cmp(&ah_entry->r_ip, src_ip))) {
				pr_debug("  Set l_spi(0x%x) = r_spi\n", spi);
				ah_entry->r_spi = spi;
			}

			return ah_entry;
		}
	}
	pr_debug("No Entry for spi(0x%x)\n", spi);

	return NULL;
}

/* PUBLIC CONNTRACK PROTO HELPER FUNCTIONS */

/* Called when a new connection for this protocol found. */
bool ah_new(struct nf_conn *ct, const struct sk_buff *skb,
	    unsigned int dataoff, unsigned int *timeouts)
{
	struct _ah_table *ah_entry;
	struct ip_auth_hdr _ahhdr, *ahhdr;
	union nf_inet_addr addr;

	ct->proto.ah.stream_timeout = timeouts[AH_CT_UNREPLIED];
	ct->proto.ah.timeout = timeouts[AH_CT_UNREPLIED];

	ahhdr = skb_header_pointer(skb, dataoff, sizeof(_ahhdr), &_ahhdr);
	if (!ahhdr)
		return false;

	memset(&addr, '\0', sizeof(addr));

	ah_entry = search_ah_entry_by_spi(ahhdr->spi, &addr);
	if (ah_entry) {
		ah_entry->ct = ct;
	} else {
		pr_debug("cannot find an entry with SPI %x\n", ahhdr->spi);
		return false;
	}

	return true;
}

/*
 * Called when the connection is deleted.
 *
 * Releases the ah_table slot bound to @ct, if any, by zeroing it; this also
 * clears ->inuse, so the slot returns to the free pool.
 */
void ah_destroy(struct nf_conn *ct)
{
	struct _ah_table *victim;

	pr_debug("DEL AH entry ct(%p)\n", ct);

	victim = search_ah_entry_by_ct(ct);
	if (!victim) {
		pr_debug("DEL AH Failed for ct(%p): no such entry\n", ct);
		return;
	}

	memset(victim, 0, sizeof(*victim));
}

/* ah hdr info to tuple */
bool ah_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff,
		     struct net *net, struct nf_conntrack_tuple *tuple)
{
	struct ip_auth_hdr _ahhdr, *ahhdr;
	struct _ah_table *ah_entry = NULL;

	ahhdr = skb_header_pointer(skb, dataoff, sizeof(_ahhdr), &_ahhdr);
	if (!ahhdr) {
		/* try to behave like "nf_conntrack_proto_generic" */
		tuple->src.u.all = 0;
		tuple->dst.u.all = 0;
		return true;
	}

	pr_debug("Enter pkt_to_tuple() with spi 0x%x\n", ahhdr->spi);
	/* check if ahhdr has a new SPI:
	 *   if no, update tuple with correct tspi and increment pkt count;
	 *   if yes, check if we have seen the source IP:
	 *             if yes, do the tspi and pkt count update
	 *             if no, create a new entry
	 */

	ah_entry = search_ah_entry_by_spi(ahhdr->spi, &tuple->src.u3);
	if (!ah_entry) {
		ah_entry = search_ah_entry_by_ip(tuple, ahhdr->spi);
		if (!ah_entry) {
			ah_entry = alloc_ah_entry();
			if (!ah_entry) {
				pr_debug("Too many entries. New spi(0x%x)\n",
					 ahhdr->spi);
				return false;
			}

			ah_entry->l_spi = ahhdr->spi;
			ah_entry->l_ip = tuple->src.u3;
			ah_entry->r_ip = tuple->dst.u3;

			pr_debug("[index %d] alloc_entry() tspi(%u) l_spi(0x%x) r_spi(0x%x) ct(%p)\n",
				 (int)(ah_entry - &ah_table[0]),
				ah_entry->tspi, ah_entry->l_spi, ah_entry->r_spi,
				ah_entry->ct);
		}
	}

	tuple->dst.u.ah.spi = tuple->src.u.ah.spi = ah_entry->tspi;
	ah_entry->pkt_rcvd++;

	return true;
}

#ifdef CONFIG_NF_CONNTRACK_PROCFS
/* Append @ct's AH timeouts, converted from jiffies to seconds, to @s. */
static void ah_print_conntrack(struct seq_file *s, struct nf_conn *ct)
{
	unsigned int timeout_secs = ct->proto.ah.timeout / HZ;
	unsigned int stream_secs = ct->proto.ah.stream_timeout / HZ;

	seq_printf(s, "timeout=%u, stream_timeout=%u ",
		   timeout_secs, stream_secs);
}
#endif

static unsigned int *ah_get_timeouts(struct net *net)
{
	return ah_pernet(net)->timeouts;
}

/*
 * Per-packet handler for AH conntracks; returns the verdict for @skb and
 * may modify @ct.
 *
 * - While @ct is unconfirmed, seeds ct->proto.ah.timeout (UNREPLIED slot)
 *   and stream_timeout (REPLIED slot). The table comes from
 *   nf_ct_timeout_lookup() if it returns one, else from the per-netns
 *   defaults.
 * - With CONFIG_NF_CONNTRACK_OFFLOAD, stores the packet's SPI for the
 *   packet's direction in the offload extension.
 * - Refreshes the conntrack timer. Once a reply has been seen it uses
 *   stream_timeout and marks @ct ASSURED, emitting IPCT_ASSURED only when
 *   the bit was not already set. Otherwise it uses timeout.
 *
 * Always returns NF_ACCEPT. If the AH header cannot be read, the packet is
 * accepted without refreshing the timer. @state is not used.
 */
int nf_conntrack_ah_packet(struct nf_conn *ct,
			   struct sk_buff *skb,
			    unsigned int dataoff,
			    enum ip_conntrack_info ctinfo,
			    const struct nf_hook_state *state)
{
	struct ip_auth_hdr _ahhdr, *ahhdr;
	struct iphdr *iph4;
	struct ipv6hdr *iph6;

#ifdef CONFIG_NF_CONNTRACK_OFFLOAD
	/*
	 * Not referenced directly below; presumably used by the
	 * ct_offload_orig/ct_offload_repl macros - confirm in
	 * nf_conntrack_offload.h before renaming.
	 */
	struct nf_conn_offload *ct_offload = nf_conn_offload_find(ct);
#endif

	if (!nf_ct_is_confirmed(ct)) {
		unsigned int *timeouts = nf_ct_timeout_lookup(ct);

		if (!timeouts)
			timeouts = ah_get_timeouts(nf_ct_net(ct));

		/* initialize to sane value.  Ideally a conntrack helper
		 * (e.g. in case of pptp) is increasing them
		 * NOTE(review): the pptp example does not appear to apply
		 * to AH.
		 */
		ct->proto.ah.stream_timeout = timeouts[AH_CT_REPLIED];
		ct->proto.ah.timeout = timeouts[AH_CT_UNREPLIED];
	}

	ahhdr = skb_header_pointer(skb, dataoff, sizeof(_ahhdr), &_ahhdr);

	/* Truncated AH header: accept without refreshing the timer. */
	if (!ahhdr)
		return NF_ACCEPT;

	/* Debug trace only; any l3num other than IPv4/IPv6 hits BUG(). */
	switch (nf_ct_l3num(ct)) {
	case AF_INET:
		iph4 = ip_hdr(skb);
		pr_debug("(0x%x) %pI4 <-> %pI4 status %s info %d %s\n",
			 ahhdr->spi, &iph4->saddr, &iph4->daddr,
		 (ct->status & IPS_SEEN_REPLY) ? "SEEN" : "NOT_SEEN",
		 ctinfo, (ctinfo == IP_CT_NEW) ? "CT_NEW" : "SEEN_REPLY");
		break;
	case AF_INET6:
		iph6 = ipv6_hdr(skb);
		pr_debug("(0x%x) %pI6c <-> %pI6c status %s info %d %s\n",
			 ahhdr->spi, &iph6->saddr, &iph6->daddr,
		 (ct->status & IPS_SEEN_REPLY) ? "SEEN" : "NOT_SEEN",
		 ctinfo, (ctinfo == IP_CT_NEW) ? "CT_NEW" : "SEEN_REPLY");
		break;
	default:
		BUG();
	}

#ifdef CONFIG_NF_CONNTRACK_OFFLOAD
	/*
	 * Record the SPI for this packet's direction.
	 * NOTE(review): nf_conn_offload_find() is not NULL-checked above;
	 * confirm the offload extension is always present on AH conntracks.
	 */
	if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL)
		ct_offload_orig.spi = ahhdr->spi;
	else
		ct_offload_repl.spi = ahhdr->spi;
#endif
	/* If we've seen traffic both ways, this is an AH connection.
	 * Extend timeout.
	 */
	if (ct->status & IPS_SEEN_REPLY) {
		nf_ct_refresh_acct(ct, ctinfo, skb,
				   ct->proto.ah.stream_timeout);
		/* Also, more likely to be important, and not a probe. */
		if (!test_and_set_bit(IPS_ASSURED_BIT, &ct->status))
			nf_conntrack_event_cache(IPCT_ASSURED, ct);
	} else {
		nf_ct_refresh_acct(ct, ctinfo, skb, ct->proto.ah.timeout);
	}

	return NF_ACCEPT;
}

#ifdef CONFIG_NF_CONNTRACK_TIMEOUT

#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nfnetlink_cttimeout.h>

/*
 * Build an AH cttimeout object from netlink attributes.
 *
 * @data receives AH_CT_MAX timeouts in jiffies. When @data is NULL, @net's
 * own default table is written instead. Every slot starts from the netns
 * defaults. Each attribute present in @tb (seconds, big-endian) then
 * overrides its slot.
 *
 * Always returns 0.
 */
static int ah_timeout_nlattr_to_obj(struct nlattr *tb[],
				    struct net *net, void *data)
{
	struct nf_ah_net *pernet = ah_pernet(net);
	unsigned int *tmo = data ? data : ah_get_timeouts(net);

	/* set default timeouts for AH. */
	tmo[AH_CT_UNREPLIED] = pernet->timeouts[AH_CT_UNREPLIED];
	tmo[AH_CT_REPLIED] = pernet->timeouts[AH_CT_REPLIED];

	if (tb[CTA_TIMEOUT_AH_UNREPLIED])
		tmo[AH_CT_UNREPLIED] =
			ntohl(nla_get_be32(tb[CTA_TIMEOUT_AH_UNREPLIED])) * HZ;

	if (tb[CTA_TIMEOUT_AH_REPLIED])
		tmo[AH_CT_REPLIED] =
			ntohl(nla_get_be32(tb[CTA_TIMEOUT_AH_REPLIED])) * HZ;

	return 0;
}

/*
 * Dump an AH cttimeout object. @data holds AH_CT_MAX timeouts in jiffies,
 * which are emitted as big-endian 32-bit seconds.
 *
 * Returns 0, or -ENOSPC if either nla_put_be32() call fails.
 */
static int
ah_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data)
{
	const unsigned int *tmo = data;

	if (nla_put_be32(skb, CTA_TIMEOUT_AH_UNREPLIED,
			 htonl(tmo[AH_CT_UNREPLIED] / HZ)))
		return -ENOSPC;

	if (nla_put_be32(skb, CTA_TIMEOUT_AH_REPLIED,
			 htonl(tmo[AH_CT_REPLIED] / HZ)))
		return -ENOSPC;

	return 0;
}

/* Netlink policy for the AH cttimeout attributes (32-bit values, in seconds). */
static const struct nla_policy
ah_timeout_nla_policy[CTA_TIMEOUT_AH_MAX + 1] = {
	[CTA_TIMEOUT_AH_REPLIED]	= { .type = NLA_U32 },
	[CTA_TIMEOUT_AH_UNREPLIED]	= { .type = NLA_U32 },
};
#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */

/* Seed @net's AH timeout table from the compiled-in ah_timeouts defaults. */
void nf_conntrack_ah_init_net(struct net *net)
{
	struct nf_ah_net *pernet = ah_pernet(net);
	int state = 0;

	while (state < AH_CT_MAX) {
		pernet->timeouts[state] = ah_timeouts[state];
		state++;
	}
}

/* Conntrack L4 protocol descriptor for IPPROTO_AH. */
const struct nf_conntrack_l4proto nf_conntrack_l4proto_ah = {
	.l4proto		= IPPROTO_AH,
#ifdef CONFIG_NF_CONNTRACK_TIMEOUT
	/* cttimeout objects are arrays of AH_CT_MAX unsigned ints */
	.ctnl_timeout = {
		.nla_policy	= ah_timeout_nla_policy,
		.nlattr_max	= CTA_TIMEOUT_AH_MAX,
		.obj_size	= AH_CT_MAX * sizeof(unsigned int),
		.obj_to_nlattr	= ah_timeout_obj_to_nlattr,
		.nlattr_to_obj	= ah_timeout_nlattr_to_obj,
	},
#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
#if IS_ENABLED(CONFIG_NF_CT_NETLINK)
	/* tuple <-> netlink conversion via the nf_ct_port_* helpers */
	.nla_policy		= nf_ct_port_nla_policy,
	.nlattr_to_tuple	= nf_ct_port_nlattr_to_tuple,
	.nlattr_tuple_size	= nf_ct_port_nlattr_tuple_size,
	.tuple_to_nlattr	= nf_ct_port_tuple_to_nlattr,
#endif
#ifdef CONFIG_NF_CONNTRACK_PROCFS
	.print_conntrack	= ah_print_conntrack,
#endif
};

#endif /* CONFIG_BCM_KF_CM */
