 /****************************************************************************
 *
 * Broadcom Proprietary and Confidential.
 * (c) 2015-2016 Broadcom. All rights reserved.
 * The term Broadcom refers to Broadcom Limited and/or its subsidiaries.
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************/
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/etherdevice.h>
#include <linux/ip.h>
#include <net/ip.h>
#include <net/ipv6.h>
#include <linux/tcp.h>
#include <linux/udp.h>
#include "sfap.h"

/* Non-zero: sfap_flow_init() allocates the flow table with kmalloc();
 * zero: the static sfap_flow_tbl[SFAP_FLOW_MAX_DEF] array is used instead.
 */
#define SFAP_FLOW_KMALLOC	1
/* Initial value of sfap_flow_max, and the size of the static table. */
#define SFAP_FLOW_MAX_DEF	128
/* Number of used-flow hash buckets is 2^SFAP_FLOW_HASH_BITS; _hash()
 * masks its result to this range.
 */
#define SFAP_FLOW_HASH_BITS	8
#define SFAP_FLOW_HASH_SIZE	(1 << SFAP_FLOW_HASH_BITS)
/* One extra bucket past the last hash bucket: it holds the free list, and
 * its index is also stored in ->hash of an unused entry.
 */
#define SFAP_FLOW_HASH_FREE	(SFAP_FLOW_HASH_SIZE)

/* Writers (sfap_flow_create()/sfap_flow_delete()) serialise list updates
 * with this spinlock; lookups walk the lists under rcu_read_lock() only.
 */
#define SFAP_FLOW_HASH_LOCK(a)		spin_lock(a)
#define SFAP_FLOW_HASH_UNLOCK(a)	spin_unlock(a)
#define SFAP_FLOW_HASH_LOCK_TYPE	DEFINE_SPINLOCK

static SFAP_FLOW_HASH_LOCK_TYPE(sfap_flow_hash_lock);

/* Buckets 0..SFAP_FLOW_HASH_SIZE-1 chain used flows; the last chains free ones. */
static struct hlist_head	sfap_flow_hash[SFAP_FLOW_HASH_SIZE+1];
/* Number of usable entries in sfap_flow_tbl (set by sfap_flow_init()). */
static int			sfap_flow_max = SFAP_FLOW_MAX_DEF;
#if SFAP_FLOW_KMALLOC
/* Flow entries; sfap_flow_tbl[i].index == i after sfap_flow_init(). */
static struct sfap_flow	*sfap_flow_tbl;
#else
static struct sfap_flow	sfap_flow_tbl[SFAP_FLOW_MAX_DEF];
#endif
/* When false, matched packets are only counted and left to the stack;
 * toggled by sfap_flow_mainp_pkt_enable().
 */
static bool			sfap_pkt_replace;
/* Flows handed out by sfap_flow_create() whose RCU release has not run yet,
 * and the high-water mark of that count.
 */
static atomic_t			sfap_flow_active;
static int			sfap_flow_active_max;
/* Packets matched per flow type in sfap_flow_mainp_pkt(). */
static long sfap_flow_match_mac_cnt;
static long sfap_flow_match_ipv4_cnt;
static long sfap_flow_match_ipv6_cnt;

/* Number of packets matched against a MAC-bridge flow since the last clear. */
long sfap_flow_match_get_mac_cnt(void)
{
	long matched = sfap_flow_match_mac_cnt;

	return matched;
}

/* Number of packets matched against an IPv4 flow since the last clear. */
long sfap_flow_match_get_ipv4_cnt(void)
{
	long matched = sfap_flow_match_ipv4_cnt;

	return matched;
}

/* Number of packets matched against an IPv6 flow since the last clear. */
long sfap_flow_match_get_ipv6_cnt(void)
{
	long matched = sfap_flow_match_ipv6_cnt;

	return matched;
}

/* Capacity of the flow table (entries usable by sfap_flow_create()). */
int sfap_flow_get_total(void)
{
	int capacity = sfap_flow_max;

	return capacity;
}

/**
 * sfap_flow_stats_clear - Reset the flow-table statistics
 *
 * Zeroes the per-type match counters and restarts the active-flow
 * high-water mark from the number of flows active right now.
 *
 * sfap_flow_active is deliberately left alone: it is a gauge of flows in
 * use (incremented by sfap_flow_create(), decremented when a deleted
 * flow's RCU callback runs), not a statistic.  Zeroing it while flows
 * exist would drive it negative as those flows are deleted later.
 */
void sfap_flow_stats_clear(void)
{
	sfap_flow_active_max = atomic_read(&sfap_flow_active);
	sfap_flow_match_mac_cnt = 0;
	sfap_flow_match_ipv4_cnt = 0;
	sfap_flow_match_ipv6_cnt = 0;
}

/* Current number of flows taken from the free list and not yet released. */
int sfap_flow_get_active(void)
{
	int in_use = atomic_read(&sfap_flow_active);

	return in_use;
}

/* Highest value sfap_flow_get_active() has reached since the last clear. */
int sfap_flow_get_active_max(void)
{
	int peak = sfap_flow_active_max;

	return peak;
}

/*
 * Enable (true) or disable (false) packet rewriting in sfap_flow_mainp_pkt().
 * While disabled, matching packets are only counted.
 */
void sfap_flow_mainp_pkt_enable(bool enable)
{
	bool replace = enable;

	sfap_pkt_replace = replace;
}

struct sfap_flow *sfap_flow_get_by_index(int index)
{
	return &(sfap_flow_tbl[index]);
}

/*
 * sfap_flow_free_get_by_index - Return the @index-th entry on the free list
 * @index: zero-based position counted from the head of the free list
 *
 * Returns NULL when the free list holds fewer than @index + 1 entries.
 * NOTE(review): the pointer is returned after rcu_read_unlock(), so the
 * caller gets no RCU protection for it.
 */
struct sfap_flow *sfap_flow_free_get_by_index(int index)
{
	struct sfap_flow *entry;
	struct sfap_flow *found = NULL;
	int pos = 0;

	rcu_read_lock();
	hlist_iterate_rcu(entry, &sfap_flow_hash[SFAP_FLOW_HASH_FREE], hlist) {
		if (pos++ == index) {
			found = entry;
			break;
		}
	}
	rcu_read_unlock();

	return found;
}

/*
 * _hash - Reduce a 32-bit key to a used-flow bucket number
 * @hash_val: key mixed by one of the sfap_flow_*_hash() helpers
 *
 * XOR-folds the upper bits down into the low bits and masks the result
 * to [0, SFAP_FLOW_HASH_SIZE).
 */
static inline u32 _hash(u32 hash_val)
{
	u32 folded = hash_val;

	folded ^= folded >> 16;
	folded ^= folded >> 8;
	folded ^= folded >> 3;

	return folded & (SFAP_FLOW_HASH_SIZE - 1);
}

/**
 * compare_ipv6_addr - Compare two IPv6 addresses
 * @src_ip: first address, as four 32-bit words
 * @dst_ip: second address, as four 32-bit words
 *
 * Neither address is modified.
 *
 * Returns 0 if the addresses are equal and 1 otherwise (memcmp-like, so
 * equality is tested with !compare_ipv6_addr()).
 */
static inline int compare_ipv6_addr(const __be32 *src_ip, const __be32 *dst_ip)
{
	return ((src_ip[0] ^ dst_ip[0]) |
		(src_ip[1] ^ dst_ip[1]) |
		(src_ip[2] ^ dst_ip[2]) |
		(src_ip[3] ^ dst_ip[3])) != 0;
}

/*
 * Bucket for an IPv4 flow.  The key is the plain sum of both addresses and
 * both ports, so swapping source and destination gives the same bucket.
 */
static inline int sfap_flow_ipv4_hash(__be32 src_ip, __be32 dst_ip,
				       __be16 src_port, __be16 dst_port)
{
	u32 key = src_ip + dst_ip;

	key += src_port;
	key += dst_port;

	return _hash(key);
}

/*
 * Bucket for an IPv6 flow.  Only the last two 32-bit words of each address
 * enter the key, XOR-ed together with both ports.
 */
static inline int sfap_flow_ipv6_hash(__be32 *src_ip, __be32 *dst_ip,
				       __be16 src_port, __be16 dst_port)
{
	u32 mix = src_ip[2] ^ src_ip[3];

	mix ^= dst_ip[2] ^ dst_ip[3];
	mix ^= src_port ^ dst_port;

	return _hash(mix);
}

/*
 * Bucket for a MAC-bridge flow.  Octets 2..5 of each MAC are packed
 * big-endian into a 32-bit word; the two words and the receive interface
 * number are summed to form the key.
 */
static inline int sfap_flow_mac_hash(u8 *src_mac,
				      u8 *dst_mac,
				      int rx_interface)
{
	u32 src_low = ((u32)src_mac[2] << 24) | ((u32)src_mac[3] << 16) |
		      ((u32)src_mac[4] << 8)  |  (u32)src_mac[5];
	u32 dst_low = ((u32)dst_mac[2] << 24) | ((u32)dst_mac[3] << 16) |
		      ((u32)dst_mac[4] << 8)  |  (u32)dst_mac[5];

	return _hash(src_low + dst_low + rx_interface);
}

/*
 * sfap_flow_free_rcu - RCU callback queued by sfap_flow_delete()
 * @h: the rcu head embedded in the deleted flow
 *
 * Wipes the flow's parameters and counters, marks it unused
 * (type ft_none, hash SFAP_FLOW_HASH_FREE), drops the active-flow count
 * and restamps both timestamps.
 *
 * NOTE(review): sfap_flow_delete() puts the entry back on the free list
 * before this callback runs.  If sfap_flow_create() reuses it within the
 * grace period, this reset clobbers the new flow.  Confirm whether that
 * window is acceptable.
 */
static void sfap_flow_free_rcu(struct rcu_head *h)
{
	struct sfap_flow *f;

	f = container_of(h, struct sfap_flow, rcu);

	/* Lookup key, rewrite data and counters all go back to zero */
	memset(&f->params, 0, sizeof(f->params));
	memset(&f->stats, 0, sizeof(struct sfap_stats));

	f->hash = SFAP_FLOW_HASH_FREE;
	f->params.type = ft_none;

	atomic_dec(&sfap_flow_active);

	ktime_get_real_ts64(&f->ctime);
	ktime_get_real_ts64(&f->utime);
}

int sfap_flow_create(enum flow_type type, int hash)
{
	struct hlist_node *hlist = sfap_flow_hash[SFAP_FLOW_HASH_FREE].first;
	struct sfap_flow *flow;
	if (!hlist)
		return -1;
	if (hash >= SFAP_FLOW_HASH_SIZE)
		return -1;
	/* Remove flow from free hash list */
	flow = (struct sfap_flow *)hlist;
	SFAP_FLOW_HASH_LOCK(&sfap_flow_hash_lock);
	hlist_del_rcu(hlist);

	/* Update flow parameters */
	flow->hash = hash;
	flow->params.type = type;
	ktime_get_real_ts64(&flow->ctime);
	ktime_get_real_ts64(&flow->utime);
	atomic_inc(&sfap_flow_active);
	if (sfap_flow_active_max < atomic_read(&sfap_flow_active))
		sfap_flow_active_max = atomic_read(&sfap_flow_active);
	/* Add flow to used hash list */
	hlist_add_head_rcu(hlist, &(sfap_flow_hash[hash]));
	SFAP_FLOW_HASH_UNLOCK(&sfap_flow_hash_lock);
	return flow->index;
}

/*
 * Allocate an ft_ipv4 flow in the bucket for the given 4-tuple.
 * Returns the new entry, or NULL if none could be allocated.  The 4-tuple
 * only selects the bucket; it is not stored in the entry here.
 */
struct sfap_flow *sfap_flow_create_ipv4(__be32 src_ip, __be32 dst_ip,
					__be16 src_port, __be16 dst_port)
{
	int index;

	index = sfap_flow_create(ft_ipv4,
				 sfap_flow_ipv4_hash(src_ip, dst_ip,
						     src_port, dst_port));
	if (index < 0 || index >= sfap_flow_max)
		return NULL;

	return sfap_flow_get_by_index(index);
}

/*
 * Allocate an ft_ipv6 flow in the bucket for the given addresses and ports.
 * Returns the new entry, or NULL if none could be allocated.  The tuple
 * only selects the bucket; it is not stored in the entry here.
 */
struct sfap_flow *sfap_flow_create_ipv6(__be32 *src_ip, __be32 *dst_ip,
					__be16 src_port, __be16 dst_port)
{
	int index;

	index = sfap_flow_create(ft_ipv6,
				 sfap_flow_ipv6_hash(src_ip, dst_ip,
						     src_port, dst_port));
	if (index < 0 || index >= sfap_flow_max)
		return NULL;

	return sfap_flow_get_by_index(index);
}

/*
 * Allocate an ft_mac_bridge flow in the bucket for the given MAC pair and
 * receive interface.  Returns the new entry, or NULL if none could be
 * allocated.  The key only selects the bucket; it is not stored here.
 */
struct sfap_flow *sfap_flow_create_mac_bridge(u8 *src_mac, u8 *dst_mac,
					      int rx_interface)
{
	int index;

	index = sfap_flow_create(ft_mac_bridge,
				 sfap_flow_mac_hash(src_mac, dst_mac,
						    rx_interface));
	if (index < 0 || index >= sfap_flow_max)
		return NULL;

	return sfap_flow_get_by_index(index);
}

void sfap_flow_delete(int index)
{
	if ((index >= 0) && (index < sfap_flow_max)) {
		struct sfap_flow *flow = &sfap_flow_tbl[index];
		struct hlist_node *hlist = &(flow->hlist);
		SFAP_FLOW_HASH_LOCK(&sfap_flow_hash_lock);
		/* Remove it from used hash list */
		hlist_del_rcu(hlist);
		call_rcu(&flow->rcu, sfap_flow_free_rcu);
		/* Add it to free hash list */
		hlist_add_head_rcu(hlist,
				   &sfap_flow_hash[SFAP_FLOW_HASH_FREE]);
		SFAP_FLOW_HASH_UNLOCK(&sfap_flow_hash_lock);
	}
}

/* Call sfap_flow_delete() on every index of the flow table, in order. */
void sfap_flow_delete_all(void)
{
	int idx = 0;

	while (idx < sfap_flow_max)
		sfap_flow_delete(idx++);
}

/*
 * sfap_flow_find_ipv4 - Look up an IPv4 flow by protocol and 4-tuple
 *
 * Walks the bucket chosen by sfap_flow_ipv4_hash() under rcu_read_lock()
 * and returns the first entry whose params.ipv4 matches all five fields,
 * or NULL.  NOTE(review): the read lock is dropped before returning, so
 * callers that keep using the entry should hold their own rcu_read_lock().
 */
struct sfap_flow  *sfap_flow_find_ipv4(__u8 ip_prot,
				       __be32 src_ip, __be32 dst_ip,
				       __be16 src_port, __be16 dst_port)
{
	struct sfap_flow *cur;
	struct sfap_flow *match = NULL;
	int bucket = sfap_flow_ipv4_hash(src_ip, dst_ip, src_port, dst_port);

	rcu_read_lock();
	hlist_iterate_rcu(cur, &sfap_flow_hash[bucket], hlist) {
		const struct flow_ipv4_params *p = &cur->params.ipv4;

		if (p->ip_prot != ip_prot ||
		    p->src_ip != src_ip ||
		    p->dst_ip != dst_ip ||
		    p->src_port != src_port ||
		    p->dst_port != dst_port)
			continue;
		match = cur;
		break;
	}
	rcu_read_unlock();

	return match;
}

/*
 * sfap_flow_find_ipv6 - Look up an IPv6 flow by protocol, addresses, ports
 *
 * Walks the bucket chosen by sfap_flow_ipv6_hash() under rcu_read_lock()
 * and returns the first entry whose params.ipv6 matches, or NULL.
 * NOTE(review): the read lock is dropped before returning, so callers that
 * keep using the entry should hold their own rcu_read_lock().
 */
struct sfap_flow  *sfap_flow_find_ipv6(__u8 ip_prot,
				       __be32 *src_ip, __be32 *dst_ip,
				       __be16 src_port, __be16 dst_port)
{
	struct sfap_flow *cur;
	struct sfap_flow *match = NULL;
	int bucket = sfap_flow_ipv6_hash(src_ip, dst_ip, src_port, dst_port);

	rcu_read_lock();
	hlist_iterate_rcu(cur, &sfap_flow_hash[bucket], hlist) {
		struct flow_ipv6_params *p = &cur->params.ipv6;

		if (p->ip_prot != ip_prot ||
		    p->src_port != src_port ||
		    p->dst_port != dst_port)
			continue;
		if (compare_ipv6_addr((__be32 *)p->src_ip, src_ip) ||
		    compare_ipv6_addr((__be32 *)p->dst_ip, dst_ip))
			continue;
		match = cur;
		break;
	}
	rcu_read_unlock();

	return match;
}

/*
 * sfap_flow_find_mac_bridge - Look up a bridge flow by MAC pair and rx port
 *
 * Walks the bucket chosen by sfap_flow_mac_hash() under rcu_read_lock()
 * and returns the first entry whose source MAC, destination MAC and
 * receive interface all match, or NULL.  NOTE(review): the read lock is
 * dropped before returning; see sfap_flow_find_ipv4().
 */
struct sfap_flow  *sfap_flow_find_mac_bridge(u8 *src_mac,
					     u8 *dst_mac,
					     int rx_interface)
{
	struct sfap_flow *cur;
	struct sfap_flow *match = NULL;
	int bucket = sfap_flow_mac_hash(src_mac, dst_mac, rx_interface);

	rcu_read_lock();
	hlist_iterate_rcu(cur, &sfap_flow_hash[bucket], hlist) {
		struct flow_mac_bridge_params *p = &cur->params.mac_bridge;

		if (p->rx_interface != rx_interface)
			continue;
		if (!ether_addr_equal(p->mac_src, src_mac) ||
		    !ether_addr_equal(p->mac_dst, dst_mac))
			continue;
		match = cur;
		break;
	}
	rcu_read_unlock();

	return match;
}

int sfap_flow_mainp_pkt(unsigned char *data, int len, struct net_device *dev)
{
	struct sfap_flow  *flow;
	struct ethhdr *eh = (struct ethhdr *) data;
	__be16 eth_proto = ntohs(eh->h_proto);

	/* L2 processing - assuming eth_type_trans is called in the driver
	   it pulls the ETH_HLEN from skb */
	if (!is_valid_ether_addr(eh->h_source))
		return -1;
	if (is_multicast_ether_addr(eh->h_dest))
		return -1;
	if (is_broadcast_ether_addr(eh->h_dest))
		return -1;
	flow = sfap_flow_find_mac_bridge(eh->h_source,
					  eh->h_dest,
					  dev->ifindex);
	if (flow) {
		struct flow_mac_bridge_params *params;
		params = &flow->params.mac_bridge;
		/* Bridge flow found */
		sfap_flow_match_mac_cnt++;
		if (!sfap_pkt_replace) {
			flow->stats.packets++;
			flow->stats.bytes += len;
			return -1;
		}
		memcpy(eh->h_source, params->mac_src, ETH_ALEN);
		memcpy(eh->h_dest,   params->mac_dst, ETH_ALEN);
		flow->stats.packets++;
		flow->stats.bytes += len;
		ktime_get_real_ts64(&flow->utime);
	} else if (eth_proto == ETH_P_IP) {
		struct iphdr *iph =  (struct iphdr *) (data + ETH_HLEN);
		struct tcphdr *tcph = NULL;
		struct udphdr *udph = NULL;
		__sum16	*tcpudp_check = NULL;
		__u8 ip_prot  = iph->protocol;
		__be32 src_ip = iph->saddr;
		__be32 dst_ip = iph->daddr;
		__be16 src_port, dst_port;
		u32 iplen;
		struct flow_ipv4_params *params;
		/*
		 *	Is the datagram acceptable?
		 *
		 *	1.	Length at least the size of an ip header
		 *	2.	Version of 4
		 *	3.	Checksums correctly.
		 *	4.	Doesn't have a bogus length
		 */
		if (iph->ihl < 5 || iph->version != 4)
			return -1;
		if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl)))
			return -1;
		iplen = ntohs(iph->tot_len);
		if (len < iplen)
			return -1;
		else if (iplen < (iph->ihl*4))
			return -1;
		if (iph->ttl == 1)
			return -1;
		if (ip_prot == IPPROTO_TCP) {
			tcph = (struct tcphdr *)((char *)iph + iph->ihl*4);
			src_port = tcph->source;
			dst_port = tcph->dest;
			tcpudp_check = &tcph->check;
			if (tcph->syn || tcph->fin || tcph->rst)
				return -1;
		} else if (ip_prot == IPPROTO_UDP) {
			udph = (struct udphdr *)((char *)iph + iph->ihl*4);
			src_port = udph->source;
			dst_port = udph->dest;
			tcpudp_check = &udph->check;
		} else
			return -1;
		flow = sfap_flow_find_ipv4(ip_prot,
					   src_ip, dst_ip,
					   src_port, dst_port);
		if (!flow)
			return -1;
		params = &flow->params.ipv4;
		sfap_flow_match_ipv4_cnt++;
		if (!sfap_pkt_replace) {
			flow->stats.packets++;
			flow->stats.bytes += len;
			return -1;
		}
		/* Replace L2 header */
		memcpy(eh->h_source,
		       params->replacement_mac_src,
		       ETH_ALEN);
		memcpy(eh->h_dest,
		       params->replacement_mac_dst,
		       ETH_ALEN);
		/* Replace L3 and L4 header */
		if (params->type == ft_ipv4_nat_src) {
			csum_replace4(&iph->check,
				      src_ip,
				      params->replacement_ip);
			csum_replace4(tcpudp_check,
				      src_ip,
				      params->replacement_ip);
			csum_replace2(tcpudp_check,
				      src_port,
				      params->replacement_port);
			iph->saddr = params->replacement_ip;
			if (ip_prot == IPPROTO_TCP)
				tcph->source = params->replacement_port;
			else
				udph->source = params->replacement_port;
		} else {
			csum_replace4(&iph->check,
				      dst_ip,
				      params->replacement_ip);
			csum_replace4(tcpudp_check,
				      dst_ip,
				      params->replacement_ip);
			csum_replace2(tcpudp_check,
				      dst_port,
				      params->replacement_port);
			iph->daddr = params->replacement_ip;
			if (ip_prot == IPPROTO_TCP)
				tcph->dest = params->replacement_port;
			else
				udph->dest = params->replacement_port;
		}
		/* TTL decrement */
		csum_replace2(&iph->check,
			      htons(iph->ttl << 8),
			      htons((iph->ttl-1) << 8));
		iph->ttl--;
		ktime_get_real_ts64(&flow->utime);
		flow->stats.packets++;
		flow->stats.bytes += len;
	} else if (eth_proto == ETH_P_IPV6) {
		struct ipv6hdr *ipv6h = NULL;
		struct tcphdr *tcph = NULL;
		struct udphdr *udph = NULL;
		__u8 ip_prot;
		__be16 src_port, dst_port;
		struct flow_ipv6_params *params;
		ipv6h =  (struct ipv6hdr *) (data + ETH_HLEN);
		ip_prot  = ipv6h->nexthdr;
		if (ipv6h->version != 6)
			return -1;
		if (ipv6h->hop_limit == 1)
			return -1;
		if (ip_prot == IPPROTO_TCP) {
			tcph = (struct tcphdr *)((char *)ipv6h + 40);
			src_port = tcph->source;
			dst_port = tcph->dest;
			if (tcph->syn || tcph->fin || tcph->rst)
				return -1;
		} else if (ip_prot == IPPROTO_UDP) {
			udph = (struct udphdr *)((char *)ipv6h + 40);
			src_port = udph->source;
			dst_port = udph->dest;
		} else
			return -1;
		flow = sfap_flow_find_ipv6(ip_prot,
					    ipv6h->saddr.in6_u.u6_addr32,
					    ipv6h->daddr.in6_u.u6_addr32,
					    src_port, dst_port);
		if (!flow)
			return -1;
		params = &flow->params.ipv6;
		sfap_flow_match_ipv6_cnt++;
		if (!sfap_pkt_replace) {
			flow->stats.packets++;
			flow->stats.bytes += len;
			return -1;
		}
		/* Replace L2 header */
		memcpy(eh->h_source,
		       params->replacement_mac_src,
		       ETH_ALEN);
		memcpy(eh->h_dest,
		       params->replacement_mac_dst,
		       ETH_ALEN);
		ipv6h->hop_limit--;
		ktime_get_real_ts64(&flow->utime);
		flow->stats.packets++;
		flow->stats.bytes += len;
	}
	if (flow)
		return flow->params.tx.interface;
	else
		return -1;
}

int sfap_flow_init(int num_flows)
{
	int i;
	sfap_flow_stats_clear();
	for (i = 0; i < SFAP_FLOW_HASH_SIZE+1; i++)
		INIT_HLIST_HEAD(&(sfap_flow_hash[i]));
#if SFAP_FLOW_KMALLOC
	sfap_flow_tbl = kmalloc((sizeof(struct sfap_flow) * (num_flows+1)),
				 GFP_KERNEL);
	sfap_flow_max = num_flows;
	if (sfap_flow_tbl == NULL)
		return -1;
#endif
	for (i = 0; i < sfap_flow_max; i++) {
		memset(&(sfap_flow_tbl[i]), 0, sizeof(sfap_flow_tbl[0]));
		sfap_flow_tbl[i].index = i;
		sfap_flow_tbl[i].hash = SFAP_FLOW_HASH_FREE;
		hlist_add_head(&(sfap_flow_tbl[i].hlist),
			       &sfap_flow_hash[SFAP_FLOW_HASH_FREE]);
	}
	return 0;
}

/**
 * sfap_flow_exit - Release the flow table
 *
 * Waits for every sfap_flow_free_rcu() callback queued by
 * sfap_flow_delete() before freeing the table.  Those callbacks write into
 * table entries and, if this file is built as a module, run this module's
 * code.  rcu_barrier() may sleep, so this must be called from process
 * context (e.g. module exit).  The table pointer and size are cleared
 * afterwards so index-checked helpers cannot reach freed memory.
 */
void sfap_flow_exit(void)
{
	rcu_barrier();
#if SFAP_FLOW_KMALLOC
	sfap_flow_max = 0;
	kfree(sfap_flow_tbl);
	sfap_flow_tbl = NULL;
#endif
}
