#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/string.h>
#include <linux/slab.h>
#include <linux/rcupdate.h>
#include <linux/skbuff.h>
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/inet.h>
#include <linux/if_ether.h>
#include <linux/if_vlan.h>
#include <linux/if_tunnel.h>
#include <net/ip.h>
#include <net/dsfield.h>
#include <net/gre.h>

#define TEB 0x6558 /* GRE protocol type: Transparent Ethernet Bridging (TEB) */

/* Netfilter hook descriptor; its fields are filled in by the init function. */
static struct nf_hook_ops nfho;
/* Binary form of target_macs[]: num_macs entries of ETH_ALEN bytes each. */
static unsigned char (*target_mac_addresses)[ETH_ALEN];
/* Number of entries supplied in target_macs[] (set by module_param_array). */
static int num_macs = 0;
/* Textual MAC addresses passed on the module command line (at most 10). */
static char *target_macs[10];
/* DSCP code point (0-63) to match on the inner packet and copy outward. */
static int dscp_value = 32;

/*
 * target_macs is exported read-only (0444).  The strings are converted into
 * target_mac_addresses exactly once, in the init function, using an
 * allocation sized by num_macs at that moment.  A writable parameter would
 * let a later sysfs write raise num_macs without resizing or re-parsing that
 * table, making the packet path read past the end of the allocation.
 */
module_param_array(target_macs, charp, &num_macs, 0444);
MODULE_PARM_DESC(target_macs, "Target MAC address list");
module_param(dscp_value, int, 0644);
MODULE_PARM_DESC(dscp_value, "DSCP value for gre encapsulated packet");

/* Return the value of one ASCII hex digit, or -1 if @c is not a hex digit. */
static int mac_hex_digit_value(char c)
{
    if (c >= '0' && c <= '9')
        return c - '0';
    if (c >= 'a' && c <= 'f')
        return c - 'a' + 10;
    if (c >= 'A' && c <= 'F')
        return c - 'A' + 10;
    return -1;
}

/*
 * string_to_mac - parse a textual MAC address into six raw bytes.
 * @str: NUL-terminated string of the exact form "xx:xx:xx:xx:xx:xx"
 *       (two hex digits per group, upper or lower case, ':' separators).
 * @mac: output buffer of at least six bytes.
 *
 * Returns 0 on success or -EINVAL if either pointer is NULL, the string is
 * not exactly 17 characters long, a group is not two hex digits, or a
 * separator is not ':'.  @mac is written only on success.
 */
int string_to_mac(const char *str, unsigned char *mac)
{
    enum { MAC_BYTES = 6, MAC_TEXT_LEN = 3 * MAC_BYTES - 1 };
    unsigned char parsed[MAC_BYTES];
    int i;

    if (!str || !mac || strlen(str) != MAC_TEXT_LEN)
        return -EINVAL;

    for (i = 0; i < MAC_BYTES; i++)
    {
        const char *group = str + 3 * i;
        int hi = mac_hex_digit_value(group[0]);
        int lo = mac_hex_digit_value(group[1]);

        if (hi < 0 || lo < 0)
            return -EINVAL;

        /*
         * Every group except the last must be followed by ':'.  The last
         * group is followed by the terminating NUL, which the length check
         * above already guarantees, so nothing past the string is read.
         */
        if (i < MAC_BYTES - 1 && group[2] != ':')
            return -EINVAL;

        parsed[i] = (unsigned char)((hi << 4) | lo);
    }

    /* Commit only after the whole string validated. */
    memcpy(mac, parsed, MAC_BYTES);
    return 0;
}

/*
 * Report whether an Ethernet header involves one of the configured MACs.
 *
 * Walks the first num_macs entries of target_mac_addresses and returns true
 * as soon as one of them equals either the source or the destination address
 * of @ethhdr; returns false if no entry matches.
 */
static bool is_target_mac(struct ethhdr *ethhdr)
{
    int idx = 0;

    while (idx < num_macs)
    {
        const unsigned char *candidate = target_mac_addresses[idx];

        /* Source address is checked before destination. */
        if (!memcmp(ethhdr->h_source, candidate, ETH_ALEN))
            return true;

        if (!memcmp(ethhdr->h_dest, candidate, ETH_ALEN))
            return true;

        idx++;
    }

    return false;
}

static unsigned int gre_packet_handler(void *priv,
                                       struct sk_buff *skb,
                                       const struct nf_hook_state *state)
{
    struct ipv6hdr *inner_ipv6_header;
    struct iphdr *outer_ip_header, *inner_ip_header;
    __be32 inner_dscp;
    unsigned int outer_dscp;
    unsigned short eth_type;

    if (skb->protocol == htons(ETH_P_IP))
    {
        outer_ip_header = ip_hdr(skb);

            if (outer_ip_header->protocol == IPPROTO_GRE)
            {
            struct gre_base_hdr *gre_hdr = (struct gre_base_hdr *)(skb_network_header(skb) + sizeof(struct iphdr));

            if (gre_hdr->protocol == htons(TEB))
            {
                struct ethhdr *inner_eth_header = (struct ethhdr *)(skb_network_header(skb) + sizeof(struct iphdr) + sizeof(struct gre_base_hdr));

                unsigned int vlan_offset = 0;
                if (ntohs(inner_eth_header->h_proto) == ETH_P_8021Q)
                {
                    vlan_offset = 4;
                }
                eth_type = ntohs(((struct ethhdr *)(skb_network_header(skb) + sizeof(struct iphdr) + sizeof(struct gre_base_hdr) + vlan_offset))->h_proto);

                if (eth_type == ETH_P_IP)
                {
                    inner_ip_header = (struct iphdr *)(skb_network_header(skb) + sizeof(struct iphdr) + sizeof(struct gre_base_hdr) + sizeof(struct ethhdr) + vlan_offset);
                    inner_dscp = (inner_ip_header->tos >> 2) & 0x3F;

                    if (inner_dscp == dscp_value)
                    {
                        if (!is_target_mac(inner_eth_header))
                        {
                            return NF_ACCEPT;
                        }
                        outer_dscp = inner_dscp << 2;
                        outer_ip_header->tos = (outer_ip_header->tos & ~0xFC) | (outer_dscp & 0xFC);
                        outer_ip_header->check = 0;
                        outer_ip_header->check = ip_fast_csum((unsigned char *)outer_ip_header, outer_ip_header->ihl);
                    }
                }
                else if (eth_type == ETH_P_IPV6)
                {
                     inner_ipv6_header = (struct ipv6hdr *)(skb_network_header(skb) + sizeof(struct iphdr) + sizeof(struct gre_base_hdr) + sizeof(struct ethhdr) + vlan_offset);

                     inner_dscp  = (inner_ipv6_header->priority);
                     inner_dscp  = (inner_dscp << 2) & 0xFC;
                     if (inner_dscp == 32)
                     {
                         if (!is_target_mac(inner_eth_header))
                         {
                             return NF_ACCEPT;
                         }
                         outer_dscp = dscp_value << 2;
                         outer_ip_header->tos = (outer_ip_header->tos & ~0xFC) | (outer_dscp & 0xFC);
                         outer_ip_header->check = 0;
                         outer_ip_header->check = ip_fast_csum((unsigned char *)outer_ip_header, outer_ip_header->ihl);
                    }
                }
            }
        }
    }
    return NF_ACCEPT;
}

/*
 * gre_dscp_module_init - module load entry point.
 *
 * Converts the target_macs[] strings into the binary table
 * target_mac_addresses and registers gre_packet_handler as an IPv4
 * POST_ROUTING hook in init_net.
 *
 * Returns 0 on success, -EINVAL if no MAC address was supplied or one is
 * malformed, -ENOMEM if the table cannot be allocated, or the error returned
 * by nf_register_net_hook().  On any failure the table is released, so a
 * failed load leaves nothing behind.
 */
static int __init gre_dscp_module_init(void)
{
    int i;
    int ret;

    if (num_macs <= 0)
    {
        printk(KERN_ERR "No MAC addresses passed to the module.\n");
        return -EINVAL;
    }

    /* kmalloc_array() checks the count * size multiplication for overflow. */
    target_mac_addresses = kmalloc_array(num_macs,
                                         sizeof(*target_mac_addresses),
                                         GFP_KERNEL);
    if (!target_mac_addresses)
    {
        printk(KERN_ERR "Failed to allocate memory for MAC addresses\n");
        return -ENOMEM;
    }

    for (i = 0; i < num_macs; i++)
    {
        printk(KERN_INFO "Target MAC[%d]: %s\n", i, target_macs[i]);
        if (string_to_mac(target_macs[i], target_mac_addresses[i]) < 0)
        {
            printk(KERN_ERR "Invalid MAC address format in target_macs[%d]\n", i);
            ret = -EINVAL;
            goto err_free;
        }
    }

    nfho.hook = gre_packet_handler;
    nfho.pf = PF_INET;
    nfho.hooknum = NF_INET_POST_ROUTING;
    nfho.priority = NF_IP_PRI_FIRST;

    /* Registration can fail; do not report success with no hook installed. */
    ret = nf_register_net_hook(&init_net, &nfho);
    if (ret)
    {
        printk(KERN_ERR "Failed to register netfilter hook (%d)\n", ret);
        goto err_free;
    }

    printk(KERN_INFO "GRE DSCP Module Initialized\n");
    return 0;

err_free:
    kfree(target_mac_addresses);
    target_mac_addresses = NULL;
    return ret;
}

/*
 * gre_dscp_module_exit - module unload entry point.
 *
 * Tears down in the reverse order of initialisation: the hook is removed
 * first so that no new packet can reach gre_packet_handler, and only then is
 * the MAC table it reads released.  Freeing the table while the hook is
 * still registered would let the handler dereference freed memory.
 */
static void __exit gre_dscp_module_exit(void)
{
    nf_unregister_net_hook(&init_net, &nfho);

    /*
     * Wait for handler invocations that were already running on other CPUs
     * when the hook was removed (netfilter hooks are invoked inside an RCU
     * read-side critical section) before freeing the data they use.
     */
    synchronize_rcu();

    kfree(target_mac_addresses);
    target_mac_addresses = NULL;

    printk(KERN_INFO "GRE DSCP Module Exited\n");
}

/* Register the functions run when the module is loaded and unloaded. */
module_init(gre_dscp_module_init);
module_exit(gre_dscp_module_exit);

/* Module metadata: licence string and one-line description. */
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("xmeshgre kernel module");
