 /****************************************************************************
 *
 * Copyright (c) 2016 Broadcom Ltd.
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************
 *   Authors: Ignatius Cheng <ignatius.cheng@broadcom.com>
 *
 *  February, 2016
 *
 ****************************************************************************/
#include <linux/module.h>
#include <linux/netfilter/x_tables.h>
#include <linux/skbuff.h>
#include <linux/printk.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/netfilter/xt_l2ogre_mt.h>
#include <linux/if_vlan.h>
#include <net/gre.h>

/* #define XT_L2OGRE_MT_DEBUG */
#ifdef XT_L2OGRE_MT_DEBUG
/*
 * xt_l2ogre_dump - print an skb's header offsets, protocol and VLAN fields,
 * followed by a hex dump of its linear data, at KERN_ERR.
 * @skb: buffer to dump
 *
 * Debug aid, compiled only when XT_L2OGRE_MT_DEBUG is defined.  Nothing in
 * this file calls it, hence __maybe_unused (avoids -Wunused-function).
 *
 * Pointers are printed with %p rather than through an integer cast, which
 * would truncate them on 64-bit builds.  Only skb_headlen() bytes are
 * hex-dumped: skb->len also counts paged fragment data, which does not
 * follow skb->data in memory and must not be read through it.
 */
static void __maybe_unused
xt_l2ogre_dump(const struct sk_buff *skb)
{
	pr_err("xt_l2ogre_dump: head=%p, len=%u, data=%p, data_len=%u, mac_len=%d hdr_len=%d\n"
	       "      (outter)  proto=%x, tp_hdr=%d, net_hdr=%d, mac_hdr=%d,\n"
	       "       (inner)  proto=%x, tp_hdr=%d, net_hdr=%d, mac_hdr=%d\n",
	       skb->head, skb->len, skb->data,
	       skb->data_len, skb->mac_len, skb->hdr_len,
	       (unsigned int)skb->protocol, skb->transport_header,
	       skb->network_header, skb->mac_header,
	       (unsigned int)skb->inner_protocol, skb->inner_transport_header,
	       skb->inner_network_header, skb->inner_mac_header);

	pr_err("xt_l2ogre_dump: vlan_proto=0x%04x, vlan_tci=0x%04x, encapsulation=%d\n",
	       (unsigned int)skb->vlan_proto, (unsigned int)skb->vlan_tci,
	       skb->encapsulation);

	/* linear part only; paged data is not contiguous with skb->data */
	print_hex_dump(KERN_ERR, "xt_l2ogre_dump: ", DUMP_PREFIX_ADDRESS, 16, 1,
		       skb->data, skb_headlen(skb), false);
}
#endif

static bool xt_l2ogre_mt(const struct sk_buff *skb,
	struct xt_action_param *par)
{
	const struct xt_l2ogre_mtinfo *info = par->matchinfo;

	int is_ipv4;
	int direction;

	/* IP parameters */
	const struct ipv6hdr *ipv6h = NULL;
	const struct iphdr *iph = NULL;
	int ip_hlen;	/* ipv6 or ipv4 header length */
	const void *remote;
	const void *local;

	/* GRE parameters */
	struct gre_base_hdr *greh, _greh;
	unsigned int greh_optlen = 0;

	/* Inner VLAN parameters */
	struct vlan_ethhdr *vlan_eh = NULL, _vlan_eh;

	if (par->state->pf != PF_INET && par->state->pf != PF_INET6)
		return false;

	is_ipv4 = (par->state->pf == PF_INET);

	/* direction of the packet, NF_INET_LOCAL_IN or NF_INET_LOCAL_OUT */
	direction = (par->state->hook ==  NF_INET_LOCAL_IN) ? 0 : 1;

	/* IP header Processing */
	if (is_ipv4) {
		iph = ip_hdr(skb);
		ip_hlen = (iph->ihl * 4);

		/* basic check that this is a gre packet */
		if (iph->protocol != IPPROTO_GRE)
			return false;
	} else {
		__u8 nexthdr;
		__be16 frag_off;

		ipv6h = ipv6_hdr(skb);
		nexthdr = ipv6h->nexthdr;
		ip_hlen = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr),
					   &nexthdr, &frag_off);
		if (ip_hlen < 0)
			return false;

		/* basic check that this is a gre packet */
		if (nexthdr != IPPROTO_GRE)
			return false;
	}

	/* GRE header Processing */
	greh = skb_header_pointer(skb, ip_hlen, sizeof(_greh), &_greh);
	if (greh == NULL) {
		par->hotdrop = true;
		return false;
	}

	/* make sure packet ethernet type is GRE over Ethernet 0x6558. */
	if (greh->protocol != __cpu_to_be16(ETH_P_TEB))
		return false;

	if (greh->flags & GRE_CSUM)
		greh_optlen += 4;
	if (greh->flags & GRE_KEY)
		greh_optlen += 4;
	if (greh->flags & GRE_SEQ)
		greh_optlen += 4;

	if (is_ipv4) {
		if (direction) {
			remote = &(iph->daddr);
			local = &(iph->saddr);
		} else {
			remote = &(iph->saddr);
			local = &(iph->daddr);
		}
	} else {
		if (direction) {
			remote = &(ipv6h->daddr);
			local = &(ipv6h->saddr);
		} else {
			remote = &(ipv6h->saddr);
			local = &(ipv6h->daddr);
		}
	}

	/* match Processing */
	if (info->flags & XT_L2OGRE_MT_REMOTE) {
		if (is_ipv4) {
			if ((*(const __be32 *)remote) != info->remote.ip)
				return false;
		} else {
			if (ipv6_addr_cmp((const struct in6_addr *)remote,
					  &info->remote.in6) != 0)
				return false;
		}
	}
	if (info->flags & XT_L2OGRE_MT_LOCAL) {
		if (is_ipv4) {
			if ((*(__be32 *)local) != info->local.ip)
				return false;
		} else {
			if (ipv6_addr_cmp((const struct in6_addr *)local,
					  &info->local.in6) != 0)
				return false;
		}
	}

	/* check GRE key Id field setting matches */
	/* This is equivalent to
	if ((!(info->flags & XT_L2OGRE_MT_USE_KEYID) &&
	     (greh->flags & GRE_KEY))
	    ) || (
	     (info->flags & XT_L2OGRE_MT_USE_KEYID) &&
	     !(greh->flags & GRE_KEY))
	    )
	   )
	*/
	if ((info->flags & XT_L2OGRE_MT_USE_KEYID) ^ (greh->flags & GRE_KEY))
		return false;

	if (info->flags & XT_L2OGRE_MT_USE_KEYID) {
		__be32 *greh_opt, _greh_opt[3];
		__be32 greh_key;
		int offset;
		/* GRE option headers */
		greh_opt = skb_header_pointer(skb, ip_hlen + sizeof(_greh),
					      greh_optlen, _greh_opt);
		if (greh_opt == NULL) {
			par->hotdrop = true;
			return false;
		}

		/* check if there is GRE CSUM field */
		offset = (greh->flags & GRE_CSUM) ? 1 : 0;
		greh_key = greh_opt[offset];

		if (__cpu_to_be32(info->key_id) != greh_key)
			return false;
	}

	/* vlan_id_check */
	if (info->flags & XT_L2OGRE_MT_VLANID_CHECK) {
		__be32 vlan_vid;

		/* make sure skb has enough data to hold all
		inner vlan tag header */
		vlan_eh = skb_header_pointer(skb, ip_hlen + sizeof(_greh) +
					     greh_optlen, sizeof(_vlan_eh),
					     &_vlan_eh);
		if (vlan_eh == NULL) {
			par->hotdrop = true;
			return false;
		}

		/* make sure, there is a VLAN tag */
		if (vlan_eh->h_vlan_proto != __cpu_to_be16(ETH_P_8021Q))
			return false;

		vlan_vid = (vlan_eh->h_vlan_TCI & __cpu_to_be16(VLAN_VID_MASK));

		/* This is equivalent to
		if ((!(info->flags & XT_L2OGRE_MT_VLANID_EXCLUDE) &&
		     (vlan_vid == __cpu_to_be32(info->vlan_id_check))
		    ) || (
		     (info->flags & XT_L2OGRE_MT_VLANID_EXCLUDE) &&
		     (vlan_vid != __cpu_to_be32(info->vlan_id_check))
		    )
		   )
		*/
		if ((!!(info->flags & XT_L2OGRE_MT_VLANID_EXCLUDE)) ^
		    (vlan_vid != __cpu_to_be16(info->vlan_id_check)))
			return false;
	}
	return true;
}

/*
 * xt_l2ogre_mt_check - validate a rule's xt_l2ogre_mtinfo when it is loaded
 * @par: checkentry parameters; par->matchinfo is the rule's match info
 *
 * Rejects a VLAN ID check whose ID is outside 0..4095 (>= VLAN_N_VID).
 * The diagnostic is rate limited because rule insertion is driven from
 * userspace, which must not be able to flood the kernel log with it.
 *
 * Return: 0 if the parameters are acceptable, -EINVAL otherwise.
 */
static int xt_l2ogre_mt_check(const struct xt_mtchk_param *par)
{
	const struct xt_l2ogre_mtinfo *info = par->matchinfo;

	if ((info->flags & XT_L2OGRE_MT_VLANID_CHECK) &&
	    info->vlan_id_check >= VLAN_N_VID) {
		pr_info_ratelimited("vlan ID check must be inclusively between 0 and 4095\n");
		return -EINVAL;
	}

	return 0;
}

/*
 * xt_l2ogre_mt_destroy - .destroy callback of the "l2ogre" match
 * @par: destructor parameters (not used)
 *
 * xt_l2ogre_mt_check() allocates nothing and keeps no per-rule state,
 * so there is nothing to release when a rule goes away.
 */
static void xt_l2ogre_mt_destroy(const struct xt_mtdtor_param *par)
{
}


/*
 * Registration table for the "l2ogre" match: one entry per address family,
 * identical apart from .family.  Every entry is tied to IPPROTO_GRE and to
 * the LOCAL_IN / LOCAL_OUT hooks; xt_l2ogre_mt() decides which address is
 * the remote tunnel endpoint from the hook it is called on.  The IPv6 entry
 * exists only when ip6tables support is configured.
 */
static struct xt_match xt_l2ogre_mt_reg[] __read_mostly = {
	{
		.name		= "l2ogre",
		.revision	= 0,
		.family		= NFPROTO_IPV4,
		.proto		= IPPROTO_GRE,
		.matchsize	= sizeof(struct xt_l2ogre_mtinfo),
		.hooks		= (1 << NF_INET_LOCAL_OUT) |
				  (1 << NF_INET_LOCAL_IN),
		.checkentry	= xt_l2ogre_mt_check,
		.match		= xt_l2ogre_mt,
		.destroy	= xt_l2ogre_mt_destroy,
		.me		= THIS_MODULE,
	},
#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES)
	{
		.name		= "l2ogre",
		.revision	= 0,
		.family		= NFPROTO_IPV6,
		.proto		= IPPROTO_GRE,
		.matchsize	= sizeof(struct xt_l2ogre_mtinfo),
		.hooks		= (1 << NF_INET_LOCAL_OUT) |
				  (1 << NF_INET_LOCAL_IN),
		.checkentry	= xt_l2ogre_mt_check,
		.match		= xt_l2ogre_mt,
		.destroy	= xt_l2ogre_mt_destroy,
		.me		= THIS_MODULE,
	},
#endif
};

/*
 * xt_l2ogre_mt_init - register every entry of xt_l2ogre_mt_reg with x_tables
 *
 * Not static and not passed to module_init() in this file; presumably it is
 * called from another part of the module (NOTE(review): confirm caller).
 *
 * Return: the result of xt_register_matches().
 */
int __init xt_l2ogre_mt_init(void)
{
	const unsigned int n_matches = ARRAY_SIZE(xt_l2ogre_mt_reg);

	return xt_register_matches(xt_l2ogre_mt_reg, n_matches);
}

/*
 * xt_l2ogre_mt_finish - unregister every entry of xt_l2ogre_mt_reg,
 * undoing xt_l2ogre_mt_init().
 */
void xt_l2ogre_mt_finish(void)
{
	const unsigned int n_matches = ARRAY_SIZE(xt_l2ogre_mt_reg);

	xt_unregister_matches(xt_l2ogre_mt_reg, n_matches);
}
