 /****************************************************************************
 *
 * Copyright (c) 2015 Broadcom Corporation
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************/

#include <linux/module.h>
#include <linux/errno.h>
#include <linux/skbuff.h>
#include <linux/in.h>
#include <linux/tcp.h>
#include <linux/ip.h>

#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/udp.h>
#include "nf_conntrack_ipsec.h"

/*
 * Module metadata. MODULE_ALIAS additionally registers the name
 * "ip_conntrack_ipsec" for this module.
 */
MODULE_AUTHOR("Pavan Kumar <pavank@broadcom.com>");
MODULE_DESCRIPTION("Netfilter connection tracking module for ipsec");
MODULE_LICENSE("GPL");
MODULE_ALIAS("ip_conntrack_ipsec");

/*
 * Protects ipsec_table[]. Taken with bottom halves disabled by
 * conntrack_ipsec_help() and ipsec_destroy().
 */
static DEFINE_SPINLOCK(nf_ipsec_lock);

/*
 * Optional hook called for ISAKMP packets travelling in the ORIGINAL
 * conntrack direction (see ipsec_outbound_pkt()). It is read with
 * rcu_dereference() and only invoked when it is non-NULL and the conntrack
 * has NAT status bits set. Nothing in this file assigns it; presumably a
 * companion NAT module does -- confirm against that module.
 * Its return value is handed back to netfilter as the helper verdict.
 */
int
(*nf_nat_ipsec_hook_outbound)(struct sk_buff *skb,
			      struct nf_conn *ct,
			      enum ip_conntrack_info ctinfo) __read_mostly;
EXPORT_SYMBOL_GPL(nf_nat_ipsec_hook_outbound);

/*
 * Optional hook called for ISAKMP packets travelling in the REPLY
 * direction (see ipsec_inbound_pkt()). In addition to the packet and
 * conntrack it receives lan_ip, the address recorded in the matching
 * ipsec_table[] slot. Same NULL / NAT-status gating as the outbound hook.
 */
int
(*nf_nat_ipsec_hook_inbound)(struct sk_buff *skb,
			     struct nf_conn *ct,
			     enum ip_conntrack_info ctinfo,
			     __be32 lan_ip) __read_mostly;
EXPORT_SYMBOL_GPL(nf_nat_ipsec_hook_inbound);

/* Forward declaration of the module exit handler defined at the end of the file. */
static void __exit nf_conntrack_ipsec_fini(void);

/* Default conntrack refresh interval for tracked ISAKMP flows, in jiffies. */
#define CT_REFRESH_TIMEOUT (60 * HZ)	 /* KT: Changed from 13 Sec to 1 Min */

/*
 * Backing storage for the "nf_conntrack_ipsec_refresh_timeout" sysctl
 * (registered only when CONFIG_SYSCTL is set); starts at CT_REFRESH_TIMEOUT.
 */
static unsigned int nf_conntrack_ipsec_refresh_timeout = CT_REFRESH_TIMEOUT;

/*
 * Internal table for ISAKMP.
 *
 * One slot per tracked ISAKMP session, at most MAX_VPN_CONNECTION slots.
 * All accesses in this file happen under nf_ipsec_lock.
 */
struct _ipsec_table {
	/* Initiator cookie copied verbatim from the ISAKMP header (packet byte order). */
	u_int32_t initcookie;
	/* Source address of the ORIGINAL-direction tuple that created/updated the slot. */
	__be32 lan_ip;
	/* Conntrack entry that last claimed this slot; used by ipsec_destroy() to find it. */
	struct nf_conn *ct;
	/* Number of packets the helper has matched against this slot (both directions). */
	int pkt_rcvd;
	/* Non-zero while the slot holds a live session; zero means free. */
	int inuse;
} ipsec_table[MAX_VPN_CONNECTION];

static struct _ipsec_table *ipsec_alloc_entry(int *index)
{
	int idx = 0;

	for (; idx < MAX_VPN_CONNECTION; idx++) {
		if (ipsec_table[idx].inuse)
			continue;

		*index = idx;
		memset(&ipsec_table[idx], 0, sizeof(struct _ipsec_table));

		pr_debug("([%d] alloc_entry()\n", idx);

		return &ipsec_table[idx];
	}

	return NULL;
}

/*
 * Search an IPsec table entry by ct.
 */
struct _ipsec_table *search_ipsec_entry_by_ct(struct nf_conn *ct)
{
	int idx = 0;

	for (; idx < MAX_VPN_CONNECTION; idx++) {
		if (!ipsec_table[idx].inuse)
			continue;

		pr_debug("Searching entry->ct(%px) <--> ct(%px)\n",
			 ipsec_table[idx].ct, ct);

		/* check ct */
		if (ipsec_table[idx].ct == ct) {
			pr_debug("Found entry with ct(%px)\n", ct);

			return &ipsec_table[idx];
		}
	}
	pr_debug("No Entry for ct(%px)\n", ct);
	return NULL;
}

/*
 * Look up an ipsec_table[] slot by the initiator cookie carried in the
 * ISAKMP header @isakmph.
 *
 * Every slot is logged, but only in-use slots can match, so cleared or
 * never-initialised slots are never returned. Returns the first matching
 * slot or NULL.
 */
struct _ipsec_table*
search_ipsec_entry_by_cookie(struct isakmp_pkt_hdr *isakmph)
{
	int slot;

	for (slot = 0; slot < MAX_VPN_CONNECTION; slot++) {
		struct _ipsec_table *entry = &ipsec_table[slot];

		pr_debug("Searching initcookie %x <-> %x\n",
			 ntohl(isakmph->initcookie),
			 ntohl(entry->initcookie));

		if (!entry->inuse)
			continue;

		if (entry->initcookie == isakmph->initcookie)
			return entry;
	}

	return NULL;
}

/*
 * Search an IPSEC table entry by the source (LAN) IP address.
 *
 * @lan_ip: address to look for, in network byte order.
 * @index:  on a match, receives the position of the slot in ipsec_table[];
 *          left untouched otherwise.
 *
 * Only slots that are currently in use are considered, matching the other
 * lookup helpers: a free slot is all-zero and must not be reported as an
 * existing session for address 0.
 *
 * Returns the first matching in-use slot, or NULL when there is none.
 */
struct _ipsec_table*
search_ipsec_entry_by_addr(const __be32 lan_ip, int *index)
{
	int idx;

	for (idx = 0; idx < MAX_VPN_CONNECTION; idx++) {
		struct _ipsec_table *ipsec_entry = &ipsec_table[idx];

		if (!ipsec_entry->inuse)
			continue;

		pr_debug("Looking up lan_ip=%pI4 table entry %pI4\n",
			 &lan_ip, &ipsec_entry->lan_ip);

		/* Both values are big-endian; equality needs no byte swap. */
		if (ipsec_entry->lan_ip == lan_ip) {
			pr_debug("Search by addr returning entry %px\n",
				 ipsec_entry);

			*index = idx;
			return ipsec_entry;
		}
	}

	return NULL;
}

/*
 * Pass a REPLY-direction ISAKMP packet to the NAT inbound hook.
 *
 * The hook is only called when one is installed and the conntrack has NAT
 * status bits set; its verdict is returned unchanged. In every other case
 * the packet is accepted.
 */
static inline int
ipsec_inbound_pkt(struct sk_buff *skb,
		  struct nf_conn *ct,
		  enum ip_conntrack_info ctinfo,
		  __be32 lan_ip)
{
	typeof(nf_nat_ipsec_hook_inbound) hook;

	pr_debug("inbound ISAKMP packet for LAN %pI4\n", &lan_ip);

	hook = rcu_dereference(nf_nat_ipsec_hook_inbound);
	if (!hook || !(ct->status & IPS_NAT_MASK))
		return NF_ACCEPT;

	return hook(skb, ct, ctinfo, lan_ip);
}

/*
 * For outgoing ISAKMP packets, we need to make sure UDP ports=500
 */
static inline int
ipsec_outbound_pkt(struct sk_buff *skb,
		   struct nf_conn *ct,
		   enum ip_conntrack_info ctinfo)
{
	typeof(nf_nat_ipsec_hook_outbound) nf_nat_ipsec_outbound;

	pr_debug("outbound ISAKMP packet skb(%px)\n", skb);

	nf_nat_ipsec_outbound = rcu_dereference(nf_nat_ipsec_hook_outbound);
	if (nf_nat_ipsec_outbound && ct->status & IPS_NAT_MASK)
		return nf_nat_ipsec_outbound(skb, ct, ctinfo);

	return NF_ACCEPT;
}

/* track cookies inside ISAKMP, call expect_related */
static int conntrack_ipsec_help(struct sk_buff *skb,
				unsigned int protoff,
				struct nf_conn *ct,
				enum ip_conntrack_info ctinfo)
{
	int dir = CTINFO2DIR(ctinfo);
	struct isakmp_pkt_hdr _isakmph, *isakmph = NULL;
	struct _ipsec_table *ipsec_entry = ipsec_table;
	int ret, index = 0;

	pr_debug("skb(%px) skb->data(%px) ct(%px) protoff(%d) offset(%d)\n",
		 skb, skb->data, ct, protoff,
		 (int)(protoff + sizeof(struct udphdr)));

	isakmph = skb_header_pointer(skb, protoff + sizeof(struct udphdr),
				     sizeof(_isakmph), &_isakmph);
	if (isakmph == NULL) {
		pr_debug(
		   "ERR: no full ISAKMP header,"
		   " can't track. isakmph=[%px]\n", isakmph);
		return NF_ACCEPT;
	}

	spin_lock_bh(&nf_ipsec_lock);

	if (dir == IP_CT_DIR_ORIGINAL) {
		int lan_ip = ct->tuplehash[dir].tuple.src.u3.ip;

		/* create one entry in the internal table if a
		   new connection is found */
		ipsec_entry = search_ipsec_entry_by_cookie(isakmph);
		if (ipsec_entry == NULL) {
			/* NOTE: cookies may be updated in the connection */
			ipsec_entry =
				search_ipsec_entry_by_addr(lan_ip, &index);
			if (ipsec_entry == NULL) {
				ipsec_entry = ipsec_alloc_entry(&index);
				if (ipsec_entry == NULL) {
					/* All entries are currently in use */
					pr_debug(
					   "ERR: Too many sessions. ct(%px)\n",
					   ct);
					spin_unlock_bh(&nf_ipsec_lock);
					return NF_DROP;
				}
			} else {
				pr_debug(
				   "EXISTING ipsec_entry[%d] with ct=%px, lan_ip=%pI4,"
				   " initcookie=%x\n",
				   index, ipsec_entry->ct, &ipsec_entry->lan_ip,
				   ntohl(ipsec_entry->initcookie));
			}

			/* KT: Guess it should be here */
			ipsec_entry->ct = ct;
			/* KT: Update our cookie information
			- moved to here */
			ipsec_entry->initcookie = isakmph->initcookie;
			ipsec_entry->lan_ip =
				 ct->tuplehash[dir].tuple.src.u3.ip;
			ipsec_entry->inuse = 1;

			pr_debug(
			   "NEW ipsec_entry[%d] with ct=%px, lan_ip=%pI4,"
			   " initcookie=%x\n",
			   index, ipsec_entry->ct, &ipsec_entry->lan_ip,
			   ntohl(ipsec_entry->initcookie));
		}
		ipsec_entry->pkt_rcvd++;

		pr_debug(
		   "L->W: initcookie=%x, lan_ip=%pI4,"
		   " dir[%d] src.u3.ip=%pI4, dst.u3.ip=%pI4\n",
		   isakmph->initcookie,
		   &ct->tuplehash[dir].tuple.src.u3.ip,
		   dir,
		   &ct->tuplehash[dir].tuple.src.u3.ip,
		   &ct->tuplehash[dir].tuple.dst.u3.ip);

		nf_ct_refresh_acct(ipsec_entry->ct, 0, skb,
				   CT_REFRESH_TIMEOUT);

		ret = ipsec_outbound_pkt(skb, ct, ctinfo);
	} else {
		pr_debug("WAN->LAN ct=%px\n", ct);

		ipsec_entry = search_ipsec_entry_by_cookie(isakmph);
		if (ipsec_entry != NULL) {
			nf_ct_refresh_acct(ipsec_entry->ct, 0, skb,
					   CT_REFRESH_TIMEOUT);
			ipsec_entry->pkt_rcvd++;

			pr_debug(
			   "W->L: initcookie=%x, lan_ip=%pI4,"
			   " dir[%d] src.u3.ip=%pI4, dst.u3.ip=%pI4\n",
			   isakmph->initcookie,
			   &ct->tuplehash[dir].tuple.src.u3.ip,
			   dir,
			   &ct->tuplehash[dir].tuple.src.u3.ip,
			   &ct->tuplehash[dir].tuple.dst.u3.ip);

			ret = ipsec_inbound_pkt(skb, ct, ctinfo,
						ipsec_entry->lan_ip);
		} else {
			pr_debug(
			   "WARNNING: client from WAN tries to connect to"
			   " VPN server in the LAN. ipsec_entry=[%px]\n",
			   ipsec_entry);
			ret = NF_ACCEPT;
		}
	}

	spin_unlock_bh(&nf_ipsec_lock);

	return ret;
}

/*
 * Helper ->destroy callback: release the ipsec_table[] slot bound to @ct.
 *
 * The slot found via search_ipsec_entry_by_ct() is wiped to zero, which
 * also marks it free. A conntrack without a slot only produces a debug
 * message.
 */
static void ipsec_destroy(struct nf_conn *ct)
{
	struct _ipsec_table *slot;

	spin_lock_bh(&nf_ipsec_lock);

	pr_debug("DEL IPsec entry ct(%px)\n", ct);

	slot = search_ipsec_entry_by_ct(ct);
	if (!slot)
		pr_debug("DEL IPsec entry failed: ct(%px)\n", ct);
	else
		memset(slot, 0, sizeof(*slot));

	spin_unlock_bh(&nf_ipsec_lock);
}

#ifdef CONFIG_SYSCTL

/* Handle returned by register_net_sysctl(); needed to unregister on exit. */
static struct ctl_table_header *nf_ct_netfilter_header;

/*
 * Sysctl entries published by this module (registered under
 * "net/netfilter" by nf_conntrack_ipsec_init()).
 */
static struct ctl_table ipsec_sysctl_table[] = {
	{
		/* Exposes nf_conntrack_ipsec_refresh_timeout, mode rw-r--r--. */
		.procname	= "nf_conntrack_ipsec_refresh_timeout",
		.data		= &nf_conntrack_ipsec_refresh_timeout,
		.maxlen		= sizeof(int),
		.mode		= 0644,
		/* Userspace sees seconds; the variable itself holds jiffies. */
		.proc_handler	= proc_dointvec_jiffies,
	},
	{ }	/* empty entry terminates the table */
};
#endif /* CONFIG_SYSCTL */

/*
 * Expectation policy attached to the helper below. No code visible in this
 * file creates expectations, so these limits only matter if something else
 * does. NOTE(review): .timeout is presumably in seconds -- confirm against
 * struct nf_conntrack_expect_policy.
 */
static const struct nf_conntrack_expect_policy ipsec_exp_policy = {
	.max_expected	= 3,
	.timeout	= 300,
};

/*
 * ISAKMP protocol helper.
 *
 * Matches IPv4 (AF_INET) UDP flows whose source-side port in the helper
 * tuple is IPSEC_PORT. Packets of such flows are passed to
 * conntrack_ipsec_help(); ipsec_destroy() runs when the conntrack entry
 * goes away.
 */
static struct nf_conntrack_helper ipsec __read_mostly = {
	.name = "ipsec",
	.me = THIS_MODULE,
	.tuple.src.l3num = AF_INET,
	.tuple.dst.protonum = IPPROTO_UDP,
	.tuple.src.u.udp.port = __constant_htons(IPSEC_PORT),

	.help = conntrack_ipsec_help,
	.destroy = ipsec_destroy,
	.expect_policy = &ipsec_exp_policy,
};

/*
 * Module entry point.
 *
 * Registers the refresh-timeout sysctl (when CONFIG_SYSCTL is enabled) and
 * then the conntrack helper. Either both registrations succeed or neither
 * is left behind: a failed sysctl registration aborts the load, and a
 * failed helper registration rolls the sysctl back so that no sysctl entry
 * keeps pointing into the memory of a module that never finished loading.
 *
 * Returns 0 on success, -ENOMEM when the sysctl table cannot be registered,
 * or the error reported by nf_conntrack_helper_register().
 */
static int __init nf_conntrack_ipsec_init(void)
{
	int ret;

#ifdef CONFIG_SYSCTL
	nf_ct_netfilter_header = register_net_sysctl(&init_net,
						     "net/netfilter",
						     ipsec_sysctl_table);
	if (!nf_ct_netfilter_header)
		return -ENOMEM;
#endif /* CONFIG_SYSCTL */

	ret = nf_conntrack_helper_register(&ipsec);

#ifdef CONFIG_SYSCTL
	if (ret) {
		/* Undo the sysctl registration; the exit handler will not run. */
		unregister_net_sysctl_table(nf_ct_netfilter_header);
		nf_ct_netfilter_header = NULL;
	}
#endif /* CONFIG_SYSCTL */

	return ret;
}

/*
 * Module exit handler: removes the sysctl entry (when CONFIG_SYSCTL is
 * enabled) and then unregisters the conntrack helper.
 */
static void __exit nf_conntrack_ipsec_fini(void)
{
#ifdef CONFIG_SYSCTL
	unregister_net_sysctl_table(nf_ct_netfilter_header);
#endif /* CONFIG_SYSCTL */

	nf_conntrack_helper_unregister(&ipsec);
}

/* Wire the init/exit handlers above into the module load/unload path. */
module_init(nf_conntrack_ipsec_init);
module_exit(nf_conntrack_ipsec_fini);
