 /****************************************************************************
 *
 * Copyright (c) 2015-2018 Broadcom. All rights reserved
 * The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************/

#include <linux/errno.h>
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/tcp.h>
#include <net/tcp.h>

#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/netfilter/nf_conntrack_expect.h>
#include <net/netfilter/nf_nat.h>
#include <net/netfilter/nf_nat_helper.h>
#include "nf_conntrack_rtsp.h"

/****************************************************************************/
/*
 * modify_ports - rewrite a port (or port pair) inside the RTSP TCP payload
 *
 * Formats either "<rtpport>" or "<rtpport><dash><rtcpport>" (when @dash is
 * non-zero) and replaces the @matchlen bytes at @matchoff with it.
 *
 * On success stores the payload length change in *@delta and returns 0.
 * On mangling failure logs (rate limited) and returns -1; *@delta is left
 * untouched in that case.
 */
static int modify_ports(struct sk_buff *skb, struct nf_conn *ct,
			enum ip_conntrack_info ctinfo,
			unsigned int protoff,
			int matchoff, int matchlen,
			u_int16_t rtpport, u_int16_t rtcpport,
			char dash, int *delta)
{
	/* Large enough for the widest possible "port<sep>port" string. */
	char portstr[sizeof("65535-65535")];
	int n;

	n = dash ? snprintf(portstr, sizeof(portstr), "%hu%c%hu",
			    rtpport, dash, rtcpport)
		 : snprintf(portstr, sizeof(portstr), "%hu", rtpport);

	if (nf_nat_mangle_tcp_packet(skb, ct, ctinfo, protoff,
				     matchoff, matchlen,
				     portstr, n)) {
		*delta = n - matchlen;
		return 0;
	}

	if (net_ratelimit())
		pr_err("nf_nat_rtsp: nf_nat_mangle_tcp_packet error\n");
	return -1;
}

/* Handles expected signalling connections and media streams */
static void nf_nat_rtsp_expected(struct nf_conn *ct,
				 struct nf_conntrack_expect *exp)
{
	struct nf_nat_range2 range;

	/* This must be a fresh one. */
	BUG_ON(ct->status & IPS_NAT_DONE_MASK);

	/* Change src to where master sends to, but only if the connection
	 * actually came from the same source. */
	if (nf_inet_addr_cmp(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3,
			     &ct->master->tuplehash[exp->dir].tuple.src.u3)) {
		range.flags = NF_NAT_RANGE_MAP_IPS;
		range.min_addr = range.max_addr
			= ct->master->tuplehash[!exp->dir].tuple.dst.u3;
		nf_nat_setup_info(ct, &range, NF_NAT_MANIP_SRC);
	}

	/* For DST manip, map port here to where it's expected. */
	range.flags = (NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED);
	range.min_proto = range.max_proto = exp->saved_proto;
	range.min_addr = range.max_addr
		= ct->master->tuplehash[!exp->dir].tuple.src.u3;
	nf_nat_setup_info(ct, &range, NF_NAT_MANIP_DST);
}

/****************************************************************************/
/* One data channel */
static int nat_rtsp_channel(struct sk_buff *skb, struct nf_conn *ct,
			    enum ip_conntrack_info ctinfo,
			    unsigned int protoff,
			    unsigned int matchoff, unsigned int matchlen,
			    struct nf_conntrack_expect *rtp_exp, int *delta)
{
	struct nf_conn_help *help = nfct_help(ct);
	struct nf_conntrack_expect *exp;
	int dir = CTINFO2DIR(ctinfo);
	u_int16_t nated_port = 0;
	int exp_exist = 0;

	/* Set expectations for NAT */
	rtp_exp->saved_proto.udp.port = rtp_exp->tuple.dst.u.udp.port;
	rtp_exp->expectfn = nf_nat_rtsp_expected;
	rtp_exp->dir = !dir;

	/* Lookup existing expects */
	spin_lock_bh(&nf_conntrack_expect_lock);
	hlist_for_each_entry(exp, &help->expectations, lnode) {
		if (exp->saved_proto.udp.port ==
		    rtp_exp->saved_proto.udp.port) {
			/* Expectation already exists */
			rtp_exp->tuple.dst.u.udp.port =
				exp->tuple.dst.u.udp.port;
			nated_port = ntohs(exp->tuple.dst.u.udp.port);
			exp_exist = 1;
			break;
		}
	}
	spin_unlock_bh(&nf_conntrack_expect_lock);

	if (exp_exist) {
		if(nf_ct_expect_related(rtp_exp, 0) == 0)
			pr_debug("nf_ct_rtsp: expect RTP ");
		goto modify_message;
	}

	/* Try to get a port. */
	for (nated_port = ntohs(rtp_exp->tuple.dst.u.udp.port);
	      nated_port != 0; nated_port++) {
		rtp_exp->tuple.dst.u.udp.port = htons(nated_port);
		if (nf_ct_expect_related(rtp_exp, 0) == 0)
			break;
	}

	if (nated_port == 0) {	/* No port available */
		if (net_ratelimit())
			pr_err("nf_nat_rtsp: out of UDP ports\n");
		return 0;
	}

modify_message:
	/* Modify message */
	if (modify_ports(skb, ct, ctinfo, protoff, matchoff, matchlen,
			 nated_port, 0, 0, delta) < 0) {
		nf_ct_unexpect_related(rtp_exp);
		return -1;
	}

	/* Success */
	pr_debug("nf_nat_rtsp: expect RTP ");
	nf_ct_dump_tuple(&rtp_exp->tuple);

	return 0;
}

/****************************************************************************/
/* A pair of data channels (RTP/RTCP) */
static int nat_rtsp_channel2(struct sk_buff *skb, struct nf_conn *ct,
			     enum ip_conntrack_info ctinfo,
			     unsigned int protoff,
			     unsigned int matchoff, unsigned int matchlen,
			     struct nf_conntrack_expect *rtp_exp,
			     struct nf_conntrack_expect *rtcp_exp,
			     char dash, int *delta)
{
	struct nf_conn_help *help = nfct_help(ct);
	struct nf_conntrack_expect *exp;
	int dir = CTINFO2DIR(ctinfo);
	u_int16_t nated_port = 0;
	int exp_exist = 0;

	/* Set expectations for NAT */
	rtp_exp->saved_proto.udp.port = rtp_exp->tuple.dst.u.udp.port;
	rtp_exp->expectfn = nf_nat_rtsp_expected;
	rtp_exp->dir = !dir;
	rtcp_exp->saved_proto.udp.port = rtcp_exp->tuple.dst.u.udp.port;
	rtcp_exp->expectfn = nf_nat_rtsp_expected;
	rtcp_exp->dir = !dir;

	/* Lookup existing expects */
	spin_lock_bh(&nf_conntrack_expect_lock);
	hlist_for_each_entry(exp, &help->expectations, lnode) {
		if (exp->saved_proto.udp.port ==
		    rtp_exp->saved_proto.udp.port) {
			/* Expectation already exists */
			rtp_exp->tuple.dst.u.udp.port =
				exp->tuple.dst.u.udp.port;
			rtcp_exp->tuple.dst.u.udp.port =
				htons(ntohs(exp->tuple.dst.u.udp.port) + 1);
			nated_port = ntohs(exp->tuple.dst.u.udp.port);
			exp_exist = 1;
			break;
		}
	}
	spin_unlock_bh(&nf_conntrack_expect_lock);

	if (exp_exist) {
		if(nf_ct_expect_related(rtp_exp, 0) == 0)
			pr_debug("nf_ct_rtsp: expect RTP ");
		if(nf_ct_expect_related(rtcp_exp, 0) == 0)
			pr_debug("nf_ct_rtsp: expect RTCP ");
		goto modify_message;
	}

	/* Try to get a pair of ports. */
	for (nated_port = ntohs(rtp_exp->tuple.dst.u.udp.port) & (~1);
	      nated_port != 0; nated_port += 2) {
		rtp_exp->tuple.dst.u.udp.port = htons(nated_port);
		if (nf_ct_expect_related(rtp_exp, 0) == 0) {
			rtcp_exp->tuple.dst.u.udp.port =
				htons(nated_port + 1);
			if (nf_ct_expect_related(rtcp_exp, 0) == 0)
				break;
			nf_ct_unexpect_related(rtp_exp);
		}
	}

	if (nated_port == 0) {	/* No port available */
		if (net_ratelimit())
			pr_err("nf_nat_rtsp: out of RTP/RTCP ports\n");
		return 0;
	}

modify_message:
	/* Modify message */
	if (modify_ports(skb, ct, ctinfo, protoff, matchoff, matchlen,
			 nated_port, nated_port + 1, dash, delta) < 0) {
		nf_ct_unexpect_related(rtp_exp);
		nf_ct_unexpect_related(rtcp_exp);
		return -1;
	}

	/* Success */
	pr_debug("nf_nat_rtsp: expect RTP ");
	nf_ct_dump_tuple(&rtp_exp->tuple);
	pr_debug("nf_nat_rtsp: expect RTCP ");
	nf_ct_dump_tuple(&rtcp_exp->tuple);

	return 0;
}

/****************************************************************************/
/*
 * lookup_mapping_port - find the original port behind a NATed port
 *
 * Scans the expectations of master @ct for one whose tuple destination
 * port equals @port and returns that expectation's ->saved_proto (the port
 * originally requested before NAT). Returns htons(0) if none matches or if
 * @ct has no helper extension. @ctinfo is currently unused.
 *
 * The expectation list is walked under nf_conntrack_expect_lock, the same
 * lock nat_rtsp_channel()/nat_rtsp_channel2() hold for this list, because
 * expectations can be unlinked concurrently (e.g. on expiry).
 */
static __be16 lookup_mapping_port(struct nf_conn *ct,
				  enum ip_conntrack_info ctinfo,
				  __be16 port)
{
	struct nf_conn_help *help = nfct_help(ct);
	struct nf_conntrack_expect *exp;
	__be16 orig_port = htons(0);

	if (!help)
		return orig_port;

	/* Lookup existing expects */
	pr_debug("nf_nat_rtsp: looking up existing expectations...\n");
	spin_lock_bh(&nf_conntrack_expect_lock);
	hlist_for_each_entry(exp, &help->expectations, lnode) {
		if (exp->tuple.dst.u.udp.port == port) {
			orig_port = exp->saved_proto.all;
			pr_debug(
			   "nf_nat_rtsp: found port %hu mapped from %hu\n",
			   ntohs(exp->tuple.dst.u.udp.port),
			   ntohs(orig_port));
			break;
		}
	}
	spin_unlock_bh(&nf_conntrack_expect_lock);

	return orig_port;
}

/****************************************************************************/
/*
 * nat_rtsp_modify_port - translate a NATed client port back in the payload
 *
 * Looks up which original port @rtpport was mapped from (see
 * lookup_mapping_port()). If there is no mapping, the payload is left
 * alone and *@delta is set to 0. Otherwise the @matchlen bytes at
 * @matchoff are replaced with the original port number.
 *
 * Returns 0 on success (or no mapping), -1 if mangling failed.
 */
static int nat_rtsp_modify_port(struct sk_buff *skb, struct nf_conn *ct,
				enum ip_conntrack_info ctinfo,
				unsigned int protoff,
				unsigned int matchoff, unsigned int matchlen,
				__be16 rtpport, int *delta)
{
	__be16 mapped = lookup_mapping_port(ct, ctinfo, rtpport);

	if (mapped != htons(0)) {
		if (modify_ports(skb, ct, ctinfo, protoff, matchoff, matchlen,
				 ntohs(mapped), 0, 0, delta) < 0)
			return -1;
		pr_debug("nf_nat_rtsp: Modified client_port from %hu to %hu\n",
			 ntohs(rtpport), ntohs(mapped));
		return 0;
	}

	/* Nothing mapped to this port: payload length is unchanged. */
	*delta = 0;
	return 0;
}

/****************************************************************************/
/*
 * nat_rtsp_modify_port2 - translate a NATed client port pair in the payload
 *
 * Like nat_rtsp_modify_port(), but writes "orig<dash>orig+1". Only @rtpport
 * is looked up; @rtcpport is not consulted — the RTCP port written is always
 * the original RTP port plus one.
 *
 * Returns 0 on success (or no mapping, with *@delta = 0), -1 if mangling
 * failed.
 */
static int nat_rtsp_modify_port2(struct sk_buff *skb, struct nf_conn *ct,
				 enum ip_conntrack_info ctinfo,
				 unsigned int protoff,
				 unsigned int matchoff, unsigned int matchlen,
				 __be16 rtpport, __be16 rtcpport,
				 char dash, int *delta)
{
	__be16 mapped = lookup_mapping_port(ct, ctinfo, rtpport);
	u_int16_t base;

	if (mapped == htons(0)) {
		/* Nothing mapped to this port: payload length is unchanged. */
		*delta = 0;
		return 0;
	}

	base = ntohs(mapped);
	if (modify_ports(skb, ct, ctinfo, protoff, matchoff, matchlen,
			 base, base + 1, dash, delta) < 0)
		return -1;

	pr_debug("nf_nat_rtsp: Modified client_port from %hu to %hu\n",
		 ntohs(rtpport), base);
	return 0;
}

/****************************************************************************/
static int nat_rtsp_modify_addr(struct sk_buff *skb, struct nf_conn *ct,
				enum ip_conntrack_info ctinfo,
				unsigned int protoff,
				int matchoff, int matchlen, int *delta)
{
	char buf[sizeof("255.255.255.255")];
	int dir = CTINFO2DIR(ctinfo);
	int len;

	/* Change the destination address to FW's WAN IP address */

	len = snprintf(buf, sizeof(buf), "%pI4",
		      &ct->tuplehash[!dir].tuple.dst.u3.ip);
	if (!nf_nat_mangle_tcp_packet(skb, ct, ctinfo, protoff,
				      matchoff, matchlen,
				      buf, len)) {
		if (net_ratelimit())
			pr_err("nf_nat_rtsp: nf_nat_mangle_tcp_packet error\n");
		return -1;
	}
	*delta = len - matchlen;
	return 0;
}

/****************************************************************************/
/*
 * init - publish the RTSP NAT hooks for the conntrack helper
 *
 * Fails with -EBUSY (and a WARN) if any hook is already set, instead of
 * crashing the kernel with BUG_ON.
 */
static int __init init(void)
{
	/* Only the pointer values are compared and no RCU read-side section
	 * is held here, so rcu_access_pointer() is the correct accessor;
	 * rcu_dereference() would trigger a "suspicious RCU usage" splat
	 * under CONFIG_PROVE_RCU.
	 */
	if (WARN_ON(rcu_access_pointer(nat_rtsp_channel_hook) != NULL ||
		    rcu_access_pointer(nat_rtsp_channel2_hook) != NULL ||
		    rcu_access_pointer(nat_rtsp_modify_port_hook) != NULL ||
		    rcu_access_pointer(nat_rtsp_modify_port2_hook) != NULL ||
		    rcu_access_pointer(nat_rtsp_modify_addr_hook) != NULL))
		return -EBUSY;

	rcu_assign_pointer(nat_rtsp_channel_hook, nat_rtsp_channel);
	rcu_assign_pointer(nat_rtsp_channel2_hook, nat_rtsp_channel2);
	rcu_assign_pointer(nat_rtsp_modify_port_hook, nat_rtsp_modify_port);
	rcu_assign_pointer(nat_rtsp_modify_port2_hook, nat_rtsp_modify_port2);
	rcu_assign_pointer(nat_rtsp_modify_addr_hook, nat_rtsp_modify_addr);

	pr_debug("nf_nat_rtsp: init success\n");
	return 0;
}

/****************************************************************************/
/*
 * fini - unpublish the RTSP NAT hooks on module unload
 */
static void __exit fini(void)
{
	/* Clear every hook, in reverse order of registration in init(). */
	rcu_assign_pointer(nat_rtsp_modify_addr_hook, NULL);
	rcu_assign_pointer(nat_rtsp_modify_port2_hook, NULL);
	rcu_assign_pointer(nat_rtsp_modify_port_hook, NULL);
	rcu_assign_pointer(nat_rtsp_channel2_hook, NULL);
	rcu_assign_pointer(nat_rtsp_channel_hook, NULL);

	/* Wait for an RCU grace period so any reader that fetched a hook
	 * before it was cleared has finished before this module goes away.
	 */
	synchronize_rcu();
}

/****************************************************************************/
/* Module entry/exit points: init() publishes the NAT hooks, fini() clears them. */
module_init(init);
module_exit(fini);

MODULE_AUTHOR("Broadcom Corporation");
MODULE_DESCRIPTION("RTSP NAT helper");
MODULE_LICENSE("GPL");
/* Also loadable under the legacy ip_nat_rtsp name. */
MODULE_ALIAS("ip_nat_rtsp");
