 /****************************************************************************
  *
  * Broadcom Proprietary and Confidential. (c) 2016 Broadcom.
  * All rights reserved.
  * The term "Broadcom" refers to Broadcom Limited and/or its subsidiaries.
  *
  * Unless you and Broadcom execute a separate written software license
  * agreement governing use of this software, this software is licensed to
  * you under the terms of the GNU General Public License version 2 (the
  * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
  * the following added to such license:
  *
  * As a special exception, the copyright holders of this software give you
  * permission to link this software with independent modules, and to copy
  * and distribute the resulting executable under terms of your choice,
  * provided that you also meet, for each linked independent module, the
  * terms and conditions of the license of that module. An independent
  * module is a module which is not derived from this software. The special
  * exception does not apply to any modifications of the software.
  *
  * Notwithstanding the above, under no circumstances may you combine this
  * software in any way with any other Broadcom software provided under a
  * license other than the GPL, without Broadcom's express prior written
  * consent.
  *
  ****************************************************************************/

#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/string.h>
#include <linux/uaccess.h>
#include <linux/skbuff.h>
#include <linux/netdevice.h>
#include <linux/netfilter.h>
#include <linux/rculist_nulls.h>
#include <linux/ip.h>
#include <linux/version.h>
#include <net/route.h>
#include <net/dst.h>
#include <net/icmp.h>
#include <net/ip.h>
#include <linux/inetdevice.h>
#include <linux/netfilter_bridge.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_tuple.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_zones.h>
#include <net/netfilter/nf_conntrack_l4proto.h>
#include <net/netfilter/nf_nat.h>

/*
 * ipv4_get_l4proto - find where the transport (L4) header of an IPv4
 * packet starts.
 * @skb:      packet to parse
 * @nhoff:    offset of the IPv4 header, as passed to skb_header_pointer()
 * @protonum: out; receives the IPv4 protocol field
 *
 * Return: @nhoff plus the IPv4 header length, or -1 when the IPv4 header
 * cannot be read, the packet is a non-first fragment, or the header length
 * points beyond skb->len. *@protonum is assigned before the skb->len check,
 * so it can already be set on that last -1 path.
 */
static int ipv4_get_l4proto(const struct sk_buff *skb, unsigned int nhoff,
			    u_int8_t *protonum)
{
	struct iphdr hdr_copy;
	const struct iphdr *hdr;
	int l4off;

	hdr = skb_header_pointer(skb, nhoff, sizeof(hdr_copy), &hdr_copy);
	if (hdr == NULL)
		return -1;

	/* A nonzero fragment offset means this piece carries no L4 header.
	 * Such fragments can still show up, e.g. quoted inside ICMP errors.
	 */
	if ((hdr->frag_off & htons(IP_OFFSET)) != 0)
		return -1;

	*protonum = hdr->protocol;
	l4off = nhoff + (hdr->ihl << 2);

	/* Accept only headers whose ihl stays within the packet length. */
	if (l4off <= skb->len)
		return l4off;

	pr_debug("bogus IPv4 packet: nhoff %u, ihl %u, skblen %u\n",
		 nhoff, hdr->ihl << 2, skb->len);
	return -1;
}

#if IS_ENABLED(CONFIG_IPV6)
/*
 * ipv6_get_l4proto - find the upper-layer header of an IPv6 packet,
 * skipping any extension headers.
 * @skb:      packet to parse
 * @nhoff:    offset of the fixed IPv6 header within @skb
 * @protonum: out; receives the final next-header value (only on success)
 *
 * Return: the offset of the upper-layer header, or -1 if the nexthdr byte
 * cannot be copied, ipv6_skip_exthdr() fails, or the packet is a non-first
 * fragment. An offset equal to skb->len (headers only, no payload) is still
 * returned as success.
 */
static int ipv6_get_l4proto(const struct sk_buff *skb, unsigned int nhoff,
			    u8 *protonum)
{
	const unsigned int exthdr_start = nhoff + sizeof(struct ipv6hdr);
	u8 proto;
	__be16 fragoff;
	int l4off;

	if (skb_copy_bits(skb, nhoff + offsetof(struct ipv6hdr, nexthdr),
			  &proto, sizeof(proto))) {
		pr_debug("can't get nexthdr\n");
		return -1;
	}

	/* Walks the extension-header chain; proto is updated in place. */
	l4off = ipv6_skip_exthdr(skb, exthdr_start, &proto, &fragoff);

	/* Success requires a valid offset and a zero fragment offset. */
	if (l4off >= 0 && (fragoff & htons(~0x7)) == 0) {
		*protonum = proto;
		return l4off;
	}

	pr_debug("can't find proto in pkt\n");
	return -1;
}
#endif

/*
 * get_l4proto - dispatch to the per-family L4 locator.
 * @skb:   packet to parse
 * @nhoff: offset of the network header within @skb
 * @pf:    NFPROTO_* family of the packet
 * @l4num: out; L4 protocol number (set to 0 for unsupported families)
 *
 * Return: the L4 header offset from the family helper, or -1 when the
 * family is not handled (IPv6 is handled only with CONFIG_IPV6).
 */
static int get_l4proto(const struct sk_buff *skb,
		       unsigned int nhoff, u8 pf, u8 *l4num)
{
	if (pf == NFPROTO_IPV4)
		return ipv4_get_l4proto(skb, nhoff, l4num);
#if IS_ENABLED(CONFIG_IPV6)
	if (pf == NFPROTO_IPV6)
		return ipv6_get_l4proto(skb, nhoff, l4num);
#endif

	/* Unsupported family: no protocol, no offset. */
	*l4num = 0;
	return -1;
}

/*iphdr+icmphdr+4bytepad+iphdr+icmphrd+2byteID+2byteSequence*/
#define MIN_LEN_ICMP_UNREACHABLE_FOR_PING_ECHO 56
static
#if LINUX_VERSION_CODE > KERNEL_VERSION(3, 17, 0)
unsigned int nf_inet_local_in(void *priv,
			      struct sk_buff *skb,
			      const struct nf_hook_state *state)
#else
unsigned int nf_inet_local_in(const struct nf_hook_ops *ops,
			      const struct net_device *in,
			      const struct net_device *out,
			      int (*okfn)(struct sk_buff *))
#endif
{
	u_int8_t protonum;
	struct iphdr *iph;
	int ret;
	struct {
		struct icmphdr	icmp;
		struct iphdr	ip;
	} *inside;
	unsigned int hdrlen = ip_hdrlen(skb);

	ret = get_l4proto(skb, skb_network_offset(skb),
			  state->pf, &protonum);
	if (ret <= 0)
		return NF_ACCEPT;

	if ((protonum != IPPROTO_ICMP))
		return NF_ACCEPT;

	inside = (void *)skb->data + hdrlen;

	if(!(inside->icmp.type == ICMP_DEST_UNREACH && inside->icmp.code == ICMP_PORT_UNREACH))
		return NF_ACCEPT;

	iph = ip_hdr(skb);
	/*If this icmp  unreachable packet is not for  icmp echo, accept it*/
	if (ntohs(iph->tot_len) < MIN_LEN_ICMP_UNREACHABLE_FOR_PING_ECHO)
		return NF_ACCEPT;

	iph = &(inside->ip);
	if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl)))
		return NF_DROP;

	return NF_ACCEPT;
}


/*
 * Netfilter hook table: one IPv4 hook at NF_INET_PRE_ROUTING using the
 * largest possible priority (INT_MAX). Netfilter runs hooks in ascending
 * priority order, so this one comes after lower-priority PRE_ROUTING hooks.
 * The table is registered/unregistered on init_net by the module init/exit
 * functions.
 */
static struct nf_hook_ops nf_ops[] __read_mostly = {
	{
		.hook     = nf_inet_local_in,
		.pf       = NFPROTO_IPV4,
		.hooknum  = NF_INET_PRE_ROUTING,
		.priority = INT_MAX,
	},
};

/*
 * init - module entry point: register the hooks in nf_ops on init_net.
 *
 * Return: 0 on success, or the negative error from nf_register_net_hooks().
 * On failure, the "Loaded" message is not printed.
 */
static int __init init(void)
{
	int ret;

	ret = nf_register_net_hooks(&init_net, nf_ops, ARRAY_SIZE(nf_ops));
	if (ret) {
		pr_err("NF ICMP Check Module: hook registration failed (%d)\n",
		       ret);
		return ret;
	}

	pr_info("NF ICMP Check Module Loaded\n");
	return 0;
}

/*
 * fini - module exit point: unregister the nf_ops hooks from init_net,
 * then log the unload.
 */
static void __exit fini(void)
{
	const unsigned int nhooks = ARRAY_SIZE(nf_ops);

	nf_unregister_net_hooks(&init_net, nf_ops, nhooks);
	pr_info("NF ICMP Check Module Exit\n");
}

/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("ICMP Check");
MODULE_ALIAS("ip_nf_icmp_check");

/* Entry and exit points. */
module_init(init);
module_exit(fini);
