 /****************************************************************************
 *
 * Copyright (c) 2021 Broadcom. All rights reserved
 * The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************
 * Author: Jayesh Patel <jayeshp@broadcom.com>
 ****************************************************************************/

#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/rculist.h>
#include <linux/spinlock.h>
#include <linux/times.h>
#include <linux/netdevice.h>
#include <linux/etherdevice.h>
#include <linux/jhash.h>
#include <linux/random.h>
#include <linux/slab.h>
#include <linux/atomic.h>
#include <asm/unaligned.h>
#include "macpt.h"

/* The following code is derived from the bridge FDB code in br_fdb.c */

#define MACPT_HASH_BITS 8
#define MACPT_HASH_SIZE (1 << MACPT_HASH_BITS)

/*
 * Serialises all writers of @hash; readers walk the chains under RCU.
 * Statically initialised so the lock is valid (and registered with lockdep)
 * before its first use.
 */
static DEFINE_SPINLOCK(hash_lock);
/* Client database buckets, each a chain of struct macpt_db_entry. */
static struct hlist_head		hash[MACPT_HASH_SIZE];

/* Slab cache backing every struct macpt_db_entry (see macpt_db_init()). */
static struct kmem_cache *macpt_db_cache __read_mostly;

/* Random seed for macpt_mac_hash(), set once in macpt_db_init(). */
static u32 db_salt __read_mostly;

/*
 * macpt_db_init - set up the client database.
 *
 * Creates the slab cache used for entries and seeds the bucket hash.
 * Returns 0 on success or -ENOMEM if the cache cannot be created.
 */
int macpt_db_init(void)
{
	struct kmem_cache *cache;

	cache = kmem_cache_create("macpt_db_cache",
				  sizeof(struct macpt_db_entry), 0,
				  SLAB_HWCACHE_ALIGN, NULL);
	macpt_db_cache = cache;
	if (cache == NULL)
		return -ENOMEM;

	get_random_bytes(&db_salt, sizeof(db_salt));
	return 0;
}

/*
 * macpt_db_fini - release the client database slab cache.
 *
 * Entries are released through call_rcu() (see db_delete()), so callbacks
 * that are still pending would free into the cache after it was destroyed.
 * rcu_barrier() waits for every outstanding callback to run first.
 *
 * NOTE(review): entries still linked in the hash table are not released
 * here; callers presumably flush the table beforehand - confirm.
 */
void macpt_db_fini(void)
{
	rcu_barrier();
	kmem_cache_destroy(macpt_db_cache);
}

/*
 * macpt_mac_hash - map a MAC address to a bucket index in @hash.
 * @mac: Ethernet address (ETH_ALEN bytes, no alignment requirement)
 *
 * Hashes bytes 2..5 of the address, salted with db_salt.
 * Returns a value in [0, MACPT_HASH_SIZE).
 */
static inline int macpt_mac_hash(const unsigned char *mac)
{
	/* use 1 byte of OUI and 3 bytes of NIC */
	u32 key = get_unaligned((const u32 *)(mac + 2));

	return jhash_1word(key, db_salt) & (MACPT_HASH_SIZE - 1);
}

/*
 * RCU callback: return an already-unlinked entry to the slab cache once the
 * grace period has elapsed.
 */
static void db_rcu_free(struct rcu_head *head)
{
	kmem_cache_free(macpt_db_cache,
			container_of(head, struct macpt_db_entry, rcu));
}

/*
 * db_delete - drop one reference to @f.
 *
 * When the last reference goes away, the entry is unlinked from its bucket
 * and freed after an RCU grace period. A NULL @f is ignored.
 * Caller must hold hash_lock.
 */
static void db_delete(struct macpt_db_entry *f)
{
	if (!f)
		return;
	if (!atomic_dec_and_test(&f->use))
		return;

	hlist_del_rcu(&f->hlist);
	call_rcu(&f->rcu, db_rcu_free);
}

/* Completely flush all dynamic entries in client database */
void macpt_db_flush(void)
{
	int i;

	spin_lock_bh(&hash_lock);
	for (i = 0; i < MACPT_HASH_SIZE; i++) {
		struct macpt_db_entry *f;
		struct hlist_node *n;
		hlist_for_each_entry_safe(f, n, &hash[i], hlist) {
			db_delete(f);
		}
	}
	spin_unlock_bh(&hash_lock);
}

/* Flush all entries referring to a specific port.
 * if do_all is set also flush static entries
 */
void macpt_db_delete_by_port(struct net_device *dev)
{
	int i;

	spin_lock_bh(&hash_lock);
	for (i = 0; i < MACPT_HASH_SIZE; i++) {
		struct hlist_node *h, *g;

		hlist_for_each_safe(h, g, &hash[i]) {
			struct macpt_db_entry *f;
			f = hlist_entry(h, struct macpt_db_entry, hlist);
			if (f->dev != dev)
				continue;

			db_delete(f);
		}
	}
	spin_unlock_bh(&hash_lock);
}

/* Internal Function: Find entry based on mac address in specific list */
static struct macpt_db_entry *db_find(struct hlist_head *head,
				      const unsigned char *addr)
{
	struct macpt_db_entry *db;

	hlist_for_each_entry(db, head, hlist) {
		if (ether_addr_equal(db->addr, addr))
			return db;
	}
	return NULL;
}

/*
 * macpt_db_find - look up the client entry for @addr.
 *
 * Returns the entry, or NULL if @addr is not in the database.
 *
 * NOTE(review): hash_lock is released before returning, so nothing here
 * keeps the entry alive. Callers presumably hold rcu_read_lock() across
 * their use of the result - confirm.
 */
struct macpt_db_entry *macpt_db_find(const unsigned char *addr)
{
	struct macpt_db_entry *found;
	int bucket = macpt_mac_hash(addr);

	spin_lock_bh(&hash_lock);
	found = db_find(&hash[bucket], addr);
	spin_unlock_bh(&hash_lock);

	return found;
}

/* Internal Function: Create entry based on input interface and mac address
   in specific list */
static struct macpt_db_entry *db_create(struct hlist_head *head,
					struct net_device *dev,
					const unsigned char *addr)
{
	struct macpt_db_entry *db;

	db = kmem_cache_alloc(macpt_db_cache, GFP_ATOMIC);
	if (db) {
		memcpy(db->addr, addr, ETH_ALEN);
		db->dev = dev;
		db->ifindex = 0;
		atomic_set(&db->use, 1);
		db->created = jiffies;
		hlist_add_head_rcu(&db->hlist, head);
	}
	return db;
}

/* Internal Function: Delete entry based on mac address */
static int db_delete_by_addr(const u8 *addr)
{
	struct hlist_head *head = &hash[macpt_mac_hash(addr)];
	struct macpt_db_entry *db;

	db = db_find(head, addr);
	if (!db)
		return -ENOENT;

	db_delete(db);
	return 0;
}

/*
 * macpt_db_delete - drop one reference to the entry for @addr.
 *
 * Returns 0 on success or -ENOENT if @addr is not in the database.
 */
int macpt_db_delete(const unsigned char *addr)
{
	int ret;

	spin_lock_bh(&hash_lock);
	ret = db_delete_by_addr(addr);
	spin_unlock_bh(&hash_lock);

	return ret;
}

/*
 * db_insert - take a reference for @addr, creating the entry on @dev if needed.
 *
 * An existing entry only gains a reference; its dev is left unchanged.
 * Returns 0, -EINVAL if is_valid_ether_addr() rejects @addr, or -ENOMEM if
 * a new entry cannot be allocated. Caller must hold hash_lock.
 */
static int db_insert(struct net_device *dev, const unsigned char *addr)
{
	struct hlist_head *bucket = &hash[macpt_mac_hash(addr)];
	struct macpt_db_entry *ent;

	if (!is_valid_ether_addr(addr))
		return -EINVAL;

	ent = db_find(bucket, addr);
	if (ent) {
		atomic_inc(&ent->use);
		return 0;
	}

	return db_create(bucket, dev, addr) ? 0 : -ENOMEM;
}

/*
 * macpt_db_insert - take a reference on the entry for @addr, creating it on
 * @dev if it does not exist yet.
 *
 * Returns 0, -EINVAL for an address rejected by is_valid_ether_addr(), or
 * -ENOMEM on allocation failure.
 */
int macpt_db_insert(struct net_device *dev, const unsigned char *addr)
{
	int err;

	spin_lock_bh(&hash_lock);
	err = db_insert(dev, addr);
	spin_unlock_bh(&hash_lock);

	return err;
}

/* Internal Function: Update entry based on input interface and mac address */
static int db_update(struct net_device *dev, const unsigned char *addr)
{
	struct hlist_head *head = &hash[macpt_mac_hash(addr)];
	struct macpt_db_entry *db;

	if (!is_valid_ether_addr(addr))
		return -EINVAL;

	db = db_find(head, addr);
	db->dev = dev;
	db->created = jiffies;

	return 0;
}

/*
 * macpt_db_update - move the existing entry for @addr to @dev.
 *
 * Serialises against other writers with hash_lock and returns the result
 * of db_update().
 */
int macpt_db_update(struct net_device *dev, const unsigned char *addr)
{
	int err;

	spin_lock_bh(&hash_lock);
	err = db_update(dev, addr);
	spin_unlock_bh(&hash_lock);

	return err;
}

/*
 * macpt_db_get_first - return the first node of the lowest-numbered
 * non-empty bucket and store that bucket index in *@hashid.
 *
 * Returns NULL, leaving *@hashid untouched, if every bucket is empty.
 * Uses rcu_dereference(), so the caller must be in an RCU read-side
 * critical section.
 */
struct hlist_node *macpt_db_get_first(int *hashid)
{
	int bucket = 0;

	while (bucket < MACPT_HASH_SIZE) {
		struct hlist_node *node;

		node = rcu_dereference(hlist_first_rcu(&hash[bucket]));
		if (node != NULL) {
			*hashid = bucket;
			return node;
		}
		bucket++;
	}
	return NULL;
}

/*
 * macpt_db_get_next - advance from @head to the next node in the database.
 *
 * When the current chain ends, continues into later buckets and updates
 * *@hashid accordingly. Returns NULL once the last bucket has been passed,
 * at which point *@hashid is at least MACPT_HASH_SIZE. Uses
 * rcu_dereference(), so the caller must be in an RCU read-side critical
 * section.
 */
struct hlist_node *macpt_db_get_next(int *hashid,
				     struct hlist_node *head)
{
	struct hlist_node *next = rcu_dereference(hlist_next_rcu(head));

	while (next == NULL) {
		*hashid += 1;
		if (*hashid >= MACPT_HASH_SIZE)
			return NULL;
		next = rcu_dereference(hlist_first_rcu(&hash[*hashid]));
	}
	return next;
}

/*
 * macpt_db_get_idx - return the node at zero-based position @pos in a walk
 * over the whole database (bucket order, then chain order).
 *
 * Returns NULL if the walk ends before reaching @pos. On success, *@hashid
 * holds the bucket of the returned node. Caller must be in an RCU read-side
 * critical section.
 */
struct hlist_node *macpt_db_get_idx(int *hashid, loff_t pos)
{
	struct hlist_node *node = macpt_db_get_first(hashid);

	for (; node && pos; pos--) {
		node = macpt_db_get_next(hashid, node);
		if (!node)
			return NULL;
	}
	return node;
}
