 /****************************************************************************
 *
 * Copyright (c) 2021 Broadcom. All rights reserved
 * The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************
 * Author: Jayesh Patel <jayeshp@broadcom.com>
 ****************************************************************************/

#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/kthread.h>
#include <linux/timer.h>
#include <linux/etherdevice.h>
#include <net/netevent.h>
#include <linux/rculist_nulls.h>
#include <net/rtnetlink.h>
#include <linux/ethtool.h>
#include <uapi/linux/ethtool.h>
#include "macpt.h"
#include "bcmnethooks.h"

#define VERSION     "0.1"
#define VER_STR     "v" VERSION

/*
 * Debug level, exposed as the writable "debug" module parameter (mode 0644).
 * NOTE(review): not read in this file; presumably consumed by other macpt
 * translation units (non-static linkage) — confirm.
 */
int macpt_dbg;
module_param_named(debug, macpt_dbg, int, 0644);
MODULE_PARM_DESC(debug, "Debug level (0 or 1)");

/*
 * The single "macpt" pass-through net_device, allocated and registered in
 * macpt_init() and used by the LAN RX hook as the target device for
 * pass-through client traffic.
 */
struct net_device *macpt_dev;

static enum bcm_nethook_result bcm_nethook_rx(
	struct net_device *dev, enum bcm_nethook_type type, void *buf)
{
	struct sk_buff *skb = (struct sk_buff *)buf;
	struct ethhdr *eh = eth_hdr(skb);
	struct macpt_db_entry *entry;

	if (!netif_carrier_ok(macpt_dev))
		return BCM_NETHOOK_PASS;

	entry = macpt_db_find(eh->h_source);
	if (entry) {
		macpt_dev->stats.rx_packets++;
		macpt_dev->stats.rx_bytes += skb->len;
		entry->dev = dev;
		entry->ifindex = dev->ifindex;
		skb->dev = macpt_dev;
		/* Assuming all dqnet device has index < 128 */
		skb->mark = 0x6d616300 | (dev->ifindex & 0xff);
		netif_receive_skb(skb);
		return BCM_NETHOOK_CONSUMED;
	}
	return BCM_NETHOOK_PASS;
}

/*
 * Install and enable the macpt RX_SKB hook (priority BCM_NETHOOK_PRIO_MAX - 1)
 * on every device in init_net whose group type is BCM_NETDEVICE_GROUP_LAN.
 * Only devices present at call time are hooked.
 *
 * NOTE(review): for_each_netdev() expects RTNL (or dev_base_lock) to be held
 * and none is taken here — confirm whether callers or the bcm_nethook API
 * provide the required locking.
 */
void macpt_register_nethook(void)
{
	struct net_device *ndev;

	for_each_netdev(&init_net, ndev) {
		int grp = BCM_NETDEVICE_GROUP_TYPE(ndev->group);

		if (grp != BCM_NETDEVICE_GROUP_LAN)
			continue;

		bcm_nethook_register_hook(ndev, BCM_NETHOOK_RX_SKB,
					  BCM_NETHOOK_PRIO_MAX - 1, "MACPT",
					  bcm_nethook_rx);
		bcm_nethook_enable_hook(ndev, BCM_NETHOOK_RX_SKB,
					bcm_nethook_rx, true);
	}
}

/*
 * Remove the macpt RX_SKB hook from every LAN-group device in init_net;
 * counterpart of macpt_register_nethook().
 *
 * NOTE(review): same locking caveat as registration — for_each_netdev() is
 * walked without RTNL here; confirm this is safe for the callers.
 */
static inline
void macpt_unregister_nethook(void)
{
	struct net_device *ndev;

	for_each_netdev(&init_net, ndev) {
		int grp = BCM_NETDEVICE_GROUP_TYPE(ndev->group);

		if (grp != BCM_NETDEVICE_GROUP_LAN)
			continue;

		bcm_nethook_unregister_hook(ndev, BCM_NETHOOK_RX_SKB,
					    bcm_nethook_rx);
	}
}

/*
 * ndo_open: start the TX queue and raise carrier. Carrier is what gates
 * diversion of client frames in bcm_nethook_rx(). Always succeeds.
 */
static int macpt_open(struct net_device *dev)
{
	const int ok = 0;

	netif_start_queue(dev);
	netif_carrier_on(dev);

	return ok;
}
/*
 * ndo_stop: stop the TX queue and drop carrier, which makes
 * bcm_nethook_rx() pass all frames through untouched. Always succeeds.
 */
static int macpt_stop(struct net_device *dev)
{
	const int ok = 0;

	netif_stop_queue(dev);
	netif_carrier_off(dev);

	return ok;
}
static int macpt_tx(struct sk_buff *skb, struct net_device *dev)
{
	struct net_device *dev_for_check = NULL;
	struct ethhdr *eh = eth_hdr(skb);
	u_int64_t if_mask = 0;
	if (is_multicast_ether_addr(eh->h_dest)) {
		struct hlist_node *node;
		struct sk_buff *nskb;
		int hash = 0;
		node = macpt_db_get_first(&hash);
		while (node) {
			struct macpt_db_entry *entry;
			entry = (struct macpt_db_entry *) node;
			if (entry && entry->ifindex &&
                            (dev_for_check = __dev_get_by_index(&init_net, entry->ifindex))){

				pr_debug("macpt_tx: want to send multicast packet(%px) to %px(%s)\n",skb, dev_for_check, dev_for_check->name);
				dev_for_check = NULL;
				if (entry->dev &&
						!(if_mask & (1<<entry->dev->ifindex))) {
					/* Clone and send */
					nskb = skb_clone(skb, GFP_ATOMIC);
					if (nskb) {
						if_mask |= (1<<entry->dev->ifindex);
						nskb->dev = entry->dev;
						pr_debug("macpt_tx: send multicast packet(%px) to %px(%d)\n", skb, entry->dev, entry->ifindex);
						dev_queue_xmit(nskb);
					}
				}
			}else{
				pr_debug("macpt_tx: discard multicast packet(%px) as entry error\n", skb);
				if (entry) {
					entry->ifindex = 0;
					entry->dev = NULL;
				}
			}
			node = macpt_db_get_next(&hash, node);
		}
		if (if_mask) {
			dev->stats.tx_packets++;
			dev->stats.tx_bytes += skb->len;
		}
	} else {
		struct macpt_db_entry *entry;
		entry = macpt_db_find(eh->h_dest);
		if (entry && entry->ifindex &&
                    (dev_for_check = __dev_get_by_index(&init_net, entry->ifindex))){
			pr_debug("macpt_tx: want to send packet(%px) to %px(%s)\n", skb, dev_for_check, dev_for_check->name);
			dev_for_check = NULL;
			if (entry && entry->dev) {
				dev->stats.tx_packets++;
				dev->stats.tx_bytes += skb->len;
				skb->dev = entry->dev;
				pr_debug("macpt_tx: send packet(%px) to %px(%d)\n", skb, entry->dev, entry->ifindex);
				dev_queue_xmit(skb);
				return NET_XMIT_SUCCESS;
			}
		}else{
			pr_debug("macpt_tx: discard packet(%px) as entry error\n", skb);
			if (entry) {
				entry->ifindex = 0;
				entry->dev = NULL;
			}
		}
	}
	dev_kfree_skb_any(skb);
	if (!if_mask)
		dev->stats.tx_dropped++;
	return NET_XMIT_SUCCESS;
}
static int macpt_set_mac_addr(struct net_device *dev, void *p)
{
	int status = 0;
	struct sockaddr *addr = (struct sockaddr *)p;

	if (netif_running(dev)) {
		netdev_err(dev, "Device busy.\n");
		status = -EBUSY;
		goto done;
	}

	memcpy(dev->dev_addr, addr->sa_data, dev->addr_len);

done:
	return status;
}
/*
 * ndo_get_stats: report the counters kept directly in dev->stats by the
 * RX hook and macpt_tx().
 */
static struct net_device_stats *
macpt_get_stats(struct net_device *dev)
{
	struct net_device_stats *counters = &dev->stats;

	return counters;
}

/*
 * ndo_fdb_add: add a pass-through client MAC (e.g. via "bridge fdb add").
 * Only unicast addresses are accepted. The entry is inserted without an
 * egress device; one is bound when the client's traffic is first seen.
 *
 * Returns the result of macpt_db_insert(), or -EINVAL for a non-unicast
 * address.
 */
int macpt_db_add(struct ndmsg *ndm, struct nlattr *tb[],
		 struct net_device *dev,
		 const unsigned char *addr,
		 u16 vid,
		 u16 flags,
		 struct netlink_ext_ack *extack)
{
	if (!is_unicast_ether_addr(addr))
		return -EINVAL;

	return macpt_db_insert(NULL, addr);
}

/*
 * ndo_fdb_del: remove a pass-through client MAC. Only unicast addresses are
 * accepted.
 *
 * Returns the result of macpt_db_delete(), or -EINVAL for a non-unicast
 * address.
 */
int macpt_db_del(struct ndmsg *ndm, struct nlattr *tb[],
		 struct net_device *dev,
		 const unsigned char *addr,
		 u16 vid)
{
	if (!is_unicast_ether_addr(addr))
		return -EINVAL;

	return macpt_db_delete(addr);
}

/*
 * ethtool get_drvinfo: report driver name "macpt" and the module VERSION.
 * The destination sizes come from the struct fields rather than a hard-coded
 * 32, and the strings are copied verbatim instead of being used as printf
 * format strings.
 */
static void macpt_ethtool_get_drvinfo(struct net_device *dev,
				      struct ethtool_drvinfo *drvinfo)
{
	strscpy(drvinfo->driver, "macpt", sizeof(drvinfo->driver));
	strscpy(drvinfo->version, VERSION, sizeof(drvinfo->version));
}

/* ethtool operations: driver info plus the generic carrier-based link state. */
static const struct ethtool_ops macpt_ethtool_ops = {
	.get_drvinfo		= macpt_ethtool_get_drvinfo,
	.get_link		= ethtool_op_get_link,
};

/*
 * macpt net_device operations. The FDB hooks let user space manage the
 * pass-through client database with standard fdb tooling.
 */
const struct net_device_ops macpt_netdev_ops = {
	/* Interface state and data path */
	.ndo_open		= macpt_open,
	.ndo_stop		= macpt_stop,
	.ndo_start_xmit		= macpt_tx,
	.ndo_get_stats		= macpt_get_stats,
	.ndo_set_mac_address	= macpt_set_mac_addr,

	/* Pass-through client database */
	.ndo_fdb_add		= macpt_db_add,
	.ndo_fdb_del		= macpt_db_del,
	.ndo_fdb_dump		= ndo_dflt_fdb_dump,
};

#define MFG_MAC_ADDR		{0x02, 0x10, 0x18}
#define ETH_MFG_ALEN		(3)
#define ETH_UNIQUE_ALEN		(ETH_ALEN - ETH_MFG_ALEN)
static u8 mfg_mac_addr[ETH_MFG_ALEN] = MFG_MAC_ADDR;
static char *mac_addr_str;

module_param_named(mac_addr, mac_addr_str, charp, 0444);

/*
 * Module entry point.
 *
 * Initialisation order matters:
 *   1. The client database comes first. As soon as the netdev is registered,
 *      user space can add entries through ndo_fdb_add. Once the LAN RX hooks
 *      are installed, bcm_nethook_rx() looks entries up.
 *   2. The interface address is chosen before register_netdev(), so it is
 *      already in place when the device becomes visible.
 *   3. The RX hooks are installed last, because bcm_nethook_rx()
 *      dereferences macpt_dev.
 *
 * The address comes from the "mac_addr" module parameter when it parses as
 * a valid unicast address. Otherwise a random address with the 02:10:18
 * prefix is generated.
 *
 * Returns 0 on success or a negative errno. On failure nothing is left
 * registered and macpt_dev is NULL.
 */
static int __init macpt_init(void)
{
	u8 addr[ETH_ALEN];
	int status;

	macpt_db_init();

	macpt_dev = alloc_etherdev(0);
	if (!macpt_dev) {
		status = -ENOMEM;
		goto err_db;
	}
	strscpy(macpt_dev->name, "macpt", IFNAMSIZ);
	macpt_dev->netdev_ops = &macpt_netdev_ops;
	macpt_dev->ethtool_ops = &macpt_ethtool_ops;

	/* mac_pton() leaves addr untouched unless the string is well-formed. */
	if (!mac_addr_str || !mac_pton(mac_addr_str, addr) ||
	    !is_valid_ether_addr(addr)) {
		if (mac_addr_str)
			pr_warn("MACPT: invalid mac_addr \"%s\", using a random address\n",
				mac_addr_str);
		memcpy(addr, mfg_mac_addr, ETH_MFG_ALEN);
		get_random_bytes(&addr[ETH_MFG_ALEN], ETH_UNIQUE_ALEN);
	}
	memcpy(macpt_dev->dev_addr, addr, ETH_ALEN);

	status = register_netdev(macpt_dev);
	if (status) {
		netdev_err(macpt_dev,
			   "Failed to register network device.\n");
		goto err_free;
	}

	macpt_procfs_init();
	macpt_register_nethook();
	pr_info("MACPT Init: %s\n", VER_STR);
	return 0;

err_free:
	free_netdev(macpt_dev);
	macpt_dev = NULL;
err_db:
	macpt_db_fini();
	return status;
}

/*
 * Module exit: tear down in the reverse order of macpt_init().
 *
 * The LAN RX hooks are removed first, so no new bcm_nethook_rx() invocation
 * can start against macpt_dev before it is unregistered and freed. procfs
 * goes next, then the netdev. The client database is flushed and finalised
 * last.
 *
 * mac_addr_str is deliberately not cleared. It is a charp module parameter,
 * and the kernel's parameter code frees its storage after this function
 * returns (param_free_charp). Overwriting the pointer here would leak that
 * allocation.
 */
static void __exit macpt_exit(void)
{
	macpt_unregister_nethook();
	macpt_procfs_exit();
	unregister_netdev(macpt_dev);
	free_netdev(macpt_dev);
	macpt_dev = NULL;
	macpt_db_flush();
	macpt_db_fini();
	pr_info("MACPT Exit:\n");
}

/* Module entry/exit registration and license declaration. */
module_init(macpt_init);
module_exit(macpt_exit);
MODULE_LICENSE("GPL");
