 /****************************************************************************
 *
 * Copyright (c) 2015-2018 Broadcom. All rights reserved
 * The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************/

#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/netfilter.h>
#include <linux/ip.h>
#include <linux/ctype.h>
#include <linux/inet.h>
#include <linux/in.h>

#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_expect.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include "nf_conntrack_rsvp.h"

/*
 * Module metadata.  MODULE_ALIAS lets the legacy name "ip_conntrack_rsvp"
 * resolve to this module.
 */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Broadcom Corporation");
MODULE_DESCRIPTION("RSVP connection tracking helper");
MODULE_ALIAS("ip_conntrack_rsvp");

/*
 * Serializes rsvp_help(): it is held (with BHs disabled) across the RSVP
 * object walk, the NAT hook calls and expectation registration.
 */
static DEFINE_SPINLOCK(nf_rsvp_lock);

/*
 * Optional NAT callback, exported (GPL-only) so another module can install
 * it; presumably the RSVP NAT module -- confirm.  rsvp_help() reads it with
 * rcu_dereference() and, when it is non-NULL and the connection has NAT
 * status bits set, calls it with the RSVP common header, the current
 * SESSION or SENDER_TEMPLATE object and an IPv4 address taken from the IP
 * header.  NOTE(review): rsvp_help() ignores the int return value.
 */
int
(*nf_nat_rsvp_convert_ip_hook)(struct rsvp_hdr *rsvph, struct rsvp_session_hdr *rsvpsh, __be32 ip);
EXPORT_SYMBOL_GPL(nf_nat_rsvp_convert_ip_hook);

/*
 * RSVP message type codes (RFC 2205, section 3.1.1), i.e. the values of the
 * common-header message type.  Not referenced by rsvp_help() itself.
 */
typedef enum {
	RSVP_PATH	= 1,
	RSVP_RESV	= 2,
	RSVP_PATHERR	= 3,
	RSVP_RESVERR	= 4,
	RSVP_PATHTEAR	= 5,
	RSVP_RESVTEAR	= 6,
	RSVP_RESVCONF	= 7,
} RSVP_TYPE;

/*
 * RSVP object Class-Num values (RFC 2205, Appendix A) that rsvp_help()
 * looks for.  Despite its name, RSVP_CLASSNUM_UDP is the SESSION object
 * class; RSVP_CLASSNUM_SENDER is the SENDER_TEMPLATE object class.
 */
typedef enum {
	RSVP_CLASSNUM_UDP	= 1,	/* SESSION */
	RSVP_CLASSNUM_SENDER	= 11,	/* SENDER_TEMPLATE */
} RSVP_SESSIONCLASSNUM;

static int rsvp_help(struct sk_buff *skb, unsigned int protoff,
		struct nf_conn *ct, enum ip_conntrack_info ctinfo)
{
	int dir = CTINFO2DIR(ctinfo);
	int ret = NF_ACCEPT;
	struct iphdr *iph = ip_hdr(skb);
	struct rsvp_hdr _rsvph, *rsvph;
	struct rsvp_session_hdr *rsvpsh;
	struct nf_conntrack_expect *exp;
	struct nf_conntrack_tuple *tuple;
	unsigned short sourceUdpPort = 0, destUdpPort = 0;
	unsigned char *startdata = NULL, *enddata = NULL, *currentptr = NULL;
	bool lan_to_wan = false;
	typeof(nf_nat_rsvp_convert_ip_hook) nf_nat_rsvp_convert_ip;
	nf_nat_rsvp_convert_ip = rcu_dereference(nf_nat_rsvp_convert_ip_hook);

	/* Get RSVP header */
	rsvph = skb_header_pointer(skb, protoff, sizeof(_rsvph), &_rsvph);
	if (rsvph == NULL)
	    return NF_ACCEPT;

	exp = nf_ct_expect_alloc(ct);
	if (exp == NULL)
	    return NF_DROP;

	pr_debug(
       "dir[%d] src.u3.ip=%pI4, dst.u3.ip=%pI4\n",
       dir,
       &ct->tuplehash[dir].tuple.src.u3.ip,
       &ct->tuplehash[dir].tuple.dst.u3.ip);
	pr_debug(
       "dir[%d] src.u3.ip=%pI4, dst.u3.ip=%pI4\n",
       !dir,
       &ct->tuplehash[!dir].tuple.src.u3.ip,
       &ct->tuplehash[!dir].tuple.dst.u3.ip);
	pr_debug("RSVP msgType %u\n", rsvph->msgType);
	pr_debug("RSVP length %u\n", ntohs(rsvph->length));

	spin_lock_bh(&nf_rsvp_lock);

	if (memcmp(&ct->tuplehash[dir].tuple.dst.u3,
				   &ct->tuplehash[!dir].tuple.src.u3,
				   sizeof(ct->tuplehash[dir].tuple.src.u3))
			    != 0)
            lan_to_wan = false;
	else if (memcmp(&ct->tuplehash[dir].tuple.src.u3,
				   &ct->tuplehash[!dir].tuple.dst.u3,
				   sizeof(ct->tuplehash[dir].tuple.src.u3))
			    != 0)
            lan_to_wan = true;

	startdata = (unsigned char *)rsvph;
	currentptr = startdata + sizeof(_rsvph);
	enddata = startdata + ntohs(rsvph->length);

	while (currentptr < enddata)	{
	    rsvpsh = (struct rsvp_session_hdr *)currentptr;

	    if (rsvpsh->length == 0) {
	        // there is somthing wrong with session, just break out
	        pr_err("found zero length session \n");
	        break;
	    }

	    if (rsvpsh->classNum == RSVP_CLASSNUM_UDP) {
                if (rsvpsh->protocol == IPPROTO_UDP)
                {
                   destUdpPort = rsvpsh->ports;
                }
                if(!lan_to_wan && nf_nat_rsvp_convert_ip && ct->status & IPS_NAT_MASK) {
                    //WAN to LAN
                    nf_nat_rsvp_convert_ip(rsvph, rsvpsh, iph->daddr);
                }
	    }

	    if (rsvpsh->classNum == RSVP_CLASSNUM_SENDER) {
               if(lan_to_wan && nf_nat_rsvp_convert_ip && ct->status & IPS_NAT_MASK) {
                    //LAN to WAN
                    nf_nat_rsvp_convert_ip(rsvph, rsvpsh, iph->saddr);
                }

                if (destUdpPort)
                 {
                    sourceUdpPort = rsvpsh->ports;
                    if(lan_to_wan) {
                        tuple = &ct->tuplehash[dir].tuple;
                        nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT,
                                  nf_ct_l3num(ct),
                                  &tuple->dst.u3, &tuple->src.u3,
                                  IPPROTO_UDP, &destUdpPort, &sourceUdpPort);
                    }
                    else {
                        tuple = &ct->tuplehash[!dir].tuple;
                        nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT,
                                  nf_ct_l3num(ct),
                                  &tuple->src.u3, &tuple->dst.u3,
                                  IPPROTO_UDP, &destUdpPort, &sourceUdpPort);
                    }
                    pr_debug("expect: ");
                    nf_ct_dump_tuple(&exp->tuple);
                    if (nf_ct_expect_related(exp, 0) != 0) {
                        ret = NF_DROP;
                        goto end;
                    }
                 }
	    }

	    currentptr += ntohs(rsvpsh->length);
        }

end:
	nf_ct_expect_put(exp);
	spin_unlock_bh(&nf_rsvp_lock);
	return ret;
}

/*
 * Expectation policy for the RSVP helper: at most one outstanding
 * expectation per master connection, with a 300-second (5 minute) timeout.
 */
static const struct nf_conntrack_expect_policy rsvp_exp_policy = {
	.timeout	= 300,	/* seconds */
	.max_expected	= 1,
};

/*
 * Helper descriptor: matches IPv4 conntracks whose protocol is
 * IPPROTO_RSVP, dispatches to rsvp_help() and limits expectations
 * through rsvp_exp_policy.
 */
static struct nf_conntrack_helper rsvp __read_mostly = {
	.name			= "RSVP",
	.me			= THIS_MODULE,
	.help			= rsvp_help,
	.expect_policy		= &rsvp_exp_policy,
	.tuple.src.l3num	= AF_INET,
	.tuple.dst.protonum	= IPPROTO_RSVP,
};

/*
 * Unregister the RSVP conntrack helper.  Deliberately not marked __exit:
 * __exit code may be discarded (e.g. when built in), so keep this an
 * ordinary function in case it is referenced from non-exit code.
 */
static void nf_conntrack_rsvp_fini(void)
{
	struct nf_conntrack_helper *helper = &rsvp;

	nf_conntrack_helper_unregister(helper);
}

/*
 * Module init: register the RSVP conntrack helper.
 *
 * Return: 0 on success, or the error from nf_conntrack_helper_register().
 * A helper whose registration failed was never linked into the helper
 * table, so it must not be unregistered here: unregistering it would
 * unlink a node that is not on any list.
 */
static int __init nf_conntrack_rsvp_init(void)
{
	int ret;

	pr_debug("nf_ct_rsvp: registering helper\n");
	ret = nf_conntrack_helper_register(&rsvp);
	if (ret)
		pr_err("nf_ct_rsvp: failed to register helper\n");
	return ret;
}

/* Register the helper on module load; unregister it on module unload. */
module_init(nf_conntrack_rsvp_init);
module_exit(nf_conntrack_rsvp_fini);
