import logging
import re
import socket
import requests
import urllib3
import subprocess
import base64
import select
import struct
import yaml
import psutil
import platform
import ipaddress
import time
import threading
from requests_toolbelt.adapters import source

# Only interfaces whose name starts with one of these prefixes get a routing agent
# (see RAThreadManager.get_active_interfaces).
NIC_PREFIXES = ["tun", "wg", "utun"]
# Default for RoutingAgent's `update_wg` argument, which is handed on to RouteManager.
UPDATE_WG = True
# The blacklist holds routes advertised by VNS3 that we ignore.
# The whitelist holds routes already present in the OS that we leave alone.
# Most of the time, routes in the whitelist should also be in the blacklist,
# since we otherwise may try to add a route that already exists in the OS.
# Being in both lists basically means "ignore".
# Entries are matched as string prefixes, like openvpn does:
# "0." will match 0.0.0.0/0, 0.1.2.3/32, etc.
# "0.0.0.0" will match 0.0.0.0/0, 0.0.0.0/1, 0.0.0.0/16, ...
# "0.0.0.0/1" will match 0.0.0.0/1, 0.0.0.0/10, 0.0.0.0/16, ...
# "0.0.0.0/1 " (note the trailing space) will match only 0.0.0.0/1
ROUTES_BLACKLIST = ["0.0.0.0"]
ROUTES_WHITELIST = ["0.0.0.0", "128.0.0.0/1 "]
# NOTE(review): not referenced in this part of the file - presumably the number of seconds
# between RAThreadManager.update_state() calls; confirm against the caller.
NIC_CHECK_INTERVAL = 5
# Default for the `api_cert_check` argument passed to ApiClient; when false, ApiClient
# suppresses urllib3's InsecureRequestWarning.
API_CERT_CHECK = False

# Constants
# Multicast group and UDP port the RoutingAgent sockets join / bind to.
MULTICAST_GROUP = '224.99.102.116'
MULTICAST_PORT = 30216
# Default HTTP Basic credentials ApiClient sends to the VNS3 API.
# NOTE(review): hard-coded credentials - confirm these are placeholders that deployments override.
API_KEY = 'overlayapi'
API_SECRET = 'x'

# Instantiate with an interface name, then call .start() and .stop() from the managing class to begin and end thread execution.
# Opens a socket on a given interface, joins multicast group, listens for multicast or udp unicast packets on given port.
# Expects to receive packets from VNS3 indicating route update and providing VIP for API call to get_connected_subnets.
# Uses an instance of RouteManager to keep system routes in sync with connected_subnets by diffing against system routes.
class RoutingAgent:
    def __init__(self, interface_name, group=MULTICAST_GROUP, port=MULTICAST_PORT, api_key=API_KEY, api_secret=API_SECRET, api_cert_check=API_CERT_CHECK, routes_blacklist=ROUTES_BLACKLIST,
        routes_whitelist=ROUTES_WHITELIST, update_wg=UPDATE_WG, logger=None):
        """Prepare (but do not start) a routing agent for one interface.

        Args:
            interface_name: Name of the network interface to listen on.
            group: Multicast group to join.
            port: UDP port to listen on.
            api_key: HTTP Basic username for the VNS3 API.
            api_secret: HTTP Basic password for the VNS3 API.
            api_cert_check: Passed through to ApiClient.
            routes_blacklist: String prefixes of advertised routes to ignore (None/empty for none).
            routes_whitelist: String prefixes of OS routes to leave alone (None/empty for none).
            update_wg: Passed through to RouteManager.
            logger: Logger to use; defaults to this module's logger.
        """
        # Fall back to a module logger so omitting `logger` does not crash on the first log call.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.running = False
        self.thread = None
        self.sock = None
        self.mcsock = None
        self.api_client = None
        self.os_name = platform.system().lower()
        self.group = group
        self.port = port
        self.api_key = api_key
        self.api_secret = api_secret
        self.api_cert_check = api_cert_check
        self.interface = {'name': interface_name}
        self.populate_interface_info()
        # Copy the lists so this instance never aliases the shared module-level defaults.
        self.routes_blacklist = list(routes_blacklist) if routes_blacklist else []
        self.routes_whitelist = list(routes_whitelist) if routes_whitelist else []
        self.route_manager = RouteManager(self.interface, update_wg=update_wg, logger=self.logger)

    # Fills out the interface dict with other relevant info we will need later; addr, addr6, cidr, indx..
    def populate_interface_info(self):
        """Fill self.interface with the address details needed later.

        Sets 'addr' and 'cidr' from the interface's non-loopback IPv4 address, 'addr6' from an
        IPv6 address that does not start with "fe80" and, on Windows, 'indx' (the interface index
        reported by netsh). Keys stay unset when the matching information is not found, including
        when the interface does not exist at all.
        """
        name = self.interface['name']
        ifaces = psutil.net_if_addrs()
        if name not in ifaces:
            return
        for addr in ifaces[name]:
            if addr.family == socket.AF_INET and not addr.address.startswith("127"):
                netmask = addr.netmask
                if netmask is None and self.os_name == "darwin":
                    # No netmask reported by psutil here, so ask ifconfig instead.
                    netmask = self._darwin_netmask(name)
                if netmask is None:
                    # "addr/None" is not a valid network; treat the address as a host route
                    # rather than raising from the constructor.
                    self.logger.warning(f"No netmask found for {addr.address} on {name}, assuming /32")
                    netmask = "255.255.255.255"
                self.logger.debug(f"Found address {addr.address}/{netmask} on {self.interface['name']}")
                self.interface['addr'] = addr.address
                self.interface['cidr'] = ipaddress.ip_network(f"{addr.address}/{netmask}", strict=False)
            elif addr.family == socket.AF_INET6 and not addr.address.startswith("fe80"):
                self.interface['addr6'] = addr.address
        # Windows-specific: route commands need the numeric interface index, which netsh lists.
        if self.os_name == "windows":
            cmd = ["netsh", "interface", "ipv4", "show", "interfaces"]
            cmd_output = subprocess.check_output(cmd).decode()
            for line in cmd_output.split('\n'):
                parts = line.split()
                # Interface rows start with the numeric index; the name is everything from the
                # fifth column on (it may contain spaces).
                if len(parts) > 4 and parts[0].isdigit() and ' '.join(parts[4:]) == name:
                    self.interface['indx'] = parts[0]

    def _darwin_netmask(self, name):
        """Return the dotted-quad IPv4 netmask ifconfig reports for `name`, or None if unavailable."""
        try:
            cmd_output = subprocess.check_output(['ifconfig', name]).decode()
        except (subprocess.CalledProcessError, OSError) as e:
            self.logger.warning(f"Could not run ifconfig for {name}: {e}")
            return None
        # ifconfig prints the mask in hex, e.g. "netmask 0xffffff00".
        netmask_search = re.search(r'netmask\s(0x[a-f0-9]+)', cmd_output)
        if not netmask_search:
            return None
        return str(ipaddress.IPv4Address(int(netmask_search.group(1), 16)))

    def start(self):
        """Open the listening socket(s) and start the background listener thread.

        Does nothing if the agent is already running. If the sockets cannot be set up the
        exception is re-raised and the agent is left stopped (not flagged as running, no
        half-open sockets).
        """
        if self.running:
            return
        try:
            self.initialize_socket()
        except Exception:
            # Release whatever was opened before the failure so a retry starts from scratch.
            for sock in (self.sock, self.mcsock):
                if sock:
                    sock.close()
            self.sock = None
            self.mcsock = None
            raise
        # Only flag as running once the sockets exist, so stop() never sees running=True
        # together with thread=None.
        self.running = True
        self.thread = threading.Thread(target=self.listen, daemon=True)
        self.thread.start()
        self.logger.debug(f"Routing agent listening for {self.interface['name']}")

    def stop(self):
        """Stop the listener thread and close every socket the agent opened.

        Does nothing if the agent is not running.
        """
        if not self.running:
            return
        self.running = False
        if self.thread is not None:
            self.thread.join()  # Wait for the listen loop to finish
        # The second (multicast) socket only exists on macOS; close it too so it is not leaked.
        for attr in ("sock", "mcsock"):
            sock = getattr(self, attr)
            if sock:
                sock.close()
                setattr(self, attr, None)

    # Linux won't receive multicast if you bind to interface address - mc group is dest of packet, doesn't match bound addr.
    # It also mixes up multiple multicast sockets joined to the same group and all bound to 0.0.0.0 since there is
    # no way to differentiate which socket you wanted to receive that packet on.
    # So what they did is SO_BINDTODEVICE where you can actually specify a NIC on which to bind 0.0.0.0.
    # Windows doesn't have or need this, because you CAN bind to interface address and still receive multicast, even though packet dest is not your bound address.
    # Equivalent of SO_BINDTODEVICE might be IP_BOUND_IF or IP_RECVIF on mac? https://stackoverflow.com/questions/20616029/os-x-equivalent-of-so-bindtodevice
    # Neither is implemented in python.. so did some digging to find numerical equivalent which is 25 in base10.
    # NOTE: the per-OS handling below works but is ugly; a cleaner approach is still to be found.
    def initialize_socket(self):
        """Create and bind the UDP socket(s) used by listen().

        Always creates self.sock. On macOS a second socket, self.mcsock, is created as well:
        it joins the multicast group and is bound to the group address, while self.sock is bound
        to the interface address. On other platforms self.sock itself joins the group.
        Reads self.interface['addr'] (and ['name'] on Linux), so populate_interface_info() must
        have found an IPv4 address first.
        """
        # Membership request: group address followed by the local interface address.
        mreq = struct.pack("4s4s", socket.inet_aton(self.group), socket.inet_aton(self.interface['addr']))
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if self.os_name == "darwin":
            # macOS: dedicated multicast socket tied to our interface address.
            self.mcsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
            self.mcsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.mcsock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(self.interface['addr']))
            self.mcsock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        else:
            self.sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        if self.os_name == "darwin":
            # Both sockets use the same port, so both need SO_REUSEPORT before binding.
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            self.mcsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            self.mcsock.bind((self.group, self.port))
            self.sock.bind((self.interface['addr'], self.port))
        elif self.os_name == "linux":
            # Bind the wildcard address but restrict the socket to our interface by name.
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, bytes(self.interface['name'], 'utf-8'))
            self.sock.bind(('', self.port))
        elif self.os_name == "windows":
            self.sock.bind((self.interface['addr'], self.port))
        # NOTE(review): on any other OS the socket is created and joined to the group but never bound.

    def reset(self):
        """Close the agent's socket(s) without stopping the listener thread.

        Errors while closing are logged, not raised. The running flag and thread are left
        untouched; RAThreadManager.update_state() recreates agents whose thread has died
        (presumably the listener fails once its socket is closed - confirm).
        """
        # Close the macOS multicast socket as well, otherwise it is left open.
        for sock in (self.sock, self.mcsock):
            if not sock:
                continue
            try:
                sock.close()
            except OSError as e:
                self.logger.info(f"Error closing socket: {e}")

    def listen(self):
        """Listener thread body: wait for datagrams and hand each one to handle_message().

        Runs until self.running is cleared. The 1 second select timeout bounds how long stop()
        has to wait for the loop to notice.
        """
        while self.running:
            # Wait on every socket at once so the macOS multicast socket does not add a second
            # one-second wait to each pass.
            socks = [self.sock]
            if self.os_name == "darwin":
                socks.append(self.mcsock)
            ready = select.select(socks, [], [], 1)[0]
            for sock in ready:
                data, _ = sock.recvfrom(1024)
                try:
                    message = data.decode('utf-8')
                except UnicodeDecodeError:
                    # Anyone can send us a datagram; garbage must not kill the listener thread.
                    self.logger.error("Invalid message format")
                    continue
                self.handle_message(message)

    def handle_message(self, message):
        """Handle one decoded update packet: "<version> <server_address> <hash>".

        Only the server address is used; it names the VNS3 API endpoint to query for routes.
        Packets that do not have exactly three whitespace-separated fields, or whose server
        address is not a plain IPv4 address, are logged and dropped.
        """
        fields = message.split()
        if len(fields) != 3:
            self.logger.error("Invalid message format")
            return
        _version, server_address, _digest = fields
        # The address comes from an unauthenticated packet and is about to receive our API
        # credentials; ApiClient also needs it to be an IPv4 address to build gateways.
        try:
            ipaddress.IPv4Address(server_address)
        except ValueError:
            self.logger.error(f"Ignoring update with invalid server address: {server_address!r}")
            return
        self.fetch_routes(server_address)
    
    def fetch_routes(self, server_address):
        """Fetch the advertised routes from VNS3 and bring the system routes in line with them.

        Args:
            server_address: Address of the VNS3 API to query.

        Nothing is changed when either the API call or reading the system routes fails.
        """
        self.api_client = ApiClient(server=server_address, api_key=self.api_key, api_secret=self.api_secret, api_cert_check=self.api_cert_check, bindaddr=self.interface['addr'], routes_blacklist=self.routes_blacklist, logger=self.logger)
        api_routes = self.api_client.get_routes()
        # ApiClient.get_routes() returns None when something has gone wrong (as opposed to an
        # empty list when VNS3 is advertising no routes).
        # We should not blow away all routes if that happens.
        if api_routes is None:
            self.logger.error("Failed to get updated routes from API")
            return
        current_routes = self.route_manager.get_current_routes()
        # get_current_routes() also reports failure as None; diffing against that would crash.
        if current_routes is None:
            self.logger.error("Failed to read current system routes; skipping route update")
            return
        # Same openvpn-style prefix matching as the blacklist: append a trailing space to the CIDR
        # so an entry such as "128.0.0.0/1 " matches exactly that route and nothing longer.
        whitelist = tuple(self.routes_whitelist)
        filtered_current_routes = [
            route for route in current_routes
            if not (str(route[0]) + " ").startswith(whitelist)
        ]
        # Diffs (ipaddress, gateway) route tuple lists and resolves system state to match API
        self.route_manager.update_routes(api_routes, filtered_current_routes)

# Instantiated by RoutingAgent upon receiving a route update packet from VNS3
class ApiClient:
    """Client for the VNS3 connected_subnets API, bound to one local address.

    Args:
        server: IPv4 address of the VNS3 API; also used as the gateway of every returned route.
        api_key: HTTP Basic username.
        api_secret: HTTP Basic password.
        api_cert_check: Whether to verify the server's TLS certificate.
        bindaddr: Local address the request is sent from.
        routes_blacklist: String prefixes of advertised routes to drop (None/empty for none).
        logger: Logger to use; defaults to this module's logger.
    """

    # Seconds to wait for the API before giving up, so a dead server cannot hang the caller forever.
    REQUEST_TIMEOUT = 10

    def __init__(self, server, api_key, api_secret, api_cert_check, bindaddr, routes_blacklist, logger=None):
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.server = server
        self.api_key = api_key
        self.api_secret = api_secret
        # Remember the setting so get_routes() can honor it.
        self.api_cert_check = bool(api_cert_check)
        self.bindaddr = bindaddr
        self.routes_blacklist = routes_blacklist if routes_blacklist else []
        if not self.api_cert_check:
            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    def get_routes(self):
        """Return the advertised routes as a list of (IPv4Network, IPv4Address gateway) tuples.

        Blacklisted routes are removed. Returns None (not an empty list) when the request or the
        parsing of its response fails, so callers can tell "no routes" from "no answer".
        """
        self.logger.debug(f"Requesting routes from VNS3 at {self.server}")
        url = f"https://{self.server}:8000/api/status/connected_subnets"
        credentials = f"{self.api_key}:{self.api_secret}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()
        headers = { "Accept": "text/plain", "Authorization": f"Basic {encoded_credentials}" }
        try:
            # Use SourceAddressAdapter to bind the request to a specific local IP address
            # Ugly, but works cross-platform
            from requests_toolbelt.adapters import source
            self.source = source.SourceAddressAdapter(self.bindaddr)
            with requests.Session() as session:
                # The URL is https, so the adapter must be mounted for that scheme to take effect.
                session.mount('http://', self.source)
                session.mount('https://', self.source)
                response = session.get(url, headers=headers, verify=self.api_cert_check, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
        except requests.RequestException as e:
            self.logger.error(f"API request failed: {e}")
            return None
        # Convert API response into our (ipaddress route, gateway) format for diffing
        advertised_routes = self.parse_routes(response.text)
        if advertised_routes is None:
            return None
        # Filter blacklisted CIDRs before returning
        # We don't check the gateway since blacklist isn't interface-specific
        return [route for route in advertised_routes if not self._is_blacklisted(route[0])]

    def _is_blacklisted(self, cidr):
        """Return True if `cidr` starts with any blacklist prefix (openvpn-style string match)."""
        # Trailing space on advertised CIDR so user can match specifically,
        # i.e. blacklisted "0.0.0.0/1 " will not match advertised "0.0.0.0/16 " where
        # blacklisted "0.0.0.0/1" would match "0.0.0.0/16 "
        cidr_str = str(cidr) + " "
        return any(cidr_str.startswith(blacklisted_prefix) for blacklisted_prefix in self.routes_blacklist)

    def parse_routes(self, api_response):
        """Parse the YAML API response into (IPv4Network, IPv4Address gateway) tuples.

        The result has the same shape as the first two fields of
        RouteManager.get_current_routes() entries, for diffing. Returns None if the response
        cannot be parsed at all.
        """
        try:
            parsed_response = yaml.safe_load(api_response)
        except yaml.YAMLError as e:
            self.logger.error(f"Error parsing response: {e}")
            return None
        return self._routes_from_items(parsed_response)

    def _routes_from_items(self, parsed_response):
        """Turn the decoded [address, netmask] pairs into route tuples, or None if unusable."""
        if not isinstance(parsed_response, list):
            # Covers an empty body (None) and any other unexpected document shape.
            self.logger.error(f"Error parsing response: expected a list, got {type(parsed_response).__name__}")
            return None
        try:
            gateway = ipaddress.IPv4Address(self.server)
        except ValueError as e:
            self.logger.error(f"Error parsing response: {e}")
            return None
        routes = []
        for item in parsed_response:
            if not (isinstance(item, list) and len(item) == 2):
                continue
            try:
                if ipaddress.ip_address(item[0]).version == 6:  # ignore ipv6 for now
                    continue
                # Combine address and netmask into an IPv4Network object
                cidr = ipaddress.IPv4Network(f"{item[0]}/{item[1]}", strict=False)
            except ValueError as e:
                # One bad entry should not discard the routes that did parse.
                self.logger.error(f"Error parsing response: {e}")
                continue
            routes.append((cidr, gateway))
        return routes

# Instantiated by RoutingAgent; each instance operates on a given interface.
class RouteManager:
    def __init__(self, interface, update_wg=False, logger=None):
        """Set up route management for one interface.

        Args:
            interface: Dict describing the interface ('name', and as available 'addr', 'cidr', 'indx').
            update_wg: Requested WireGuard handling; only kept if is_wg() confirms the interface.
            logger: Logger to use; defaults to this module's logger.
        """
        # Fall back to a module logger so omitting `logger` does not crash on the debug call below.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.os_name = platform.system().lower()
        self.interface = interface
        # is_wg() is only consulted when WireGuard handling was requested.
        self.update_wg = update_wg and self.is_wg()
        self.logger.debug(f"RouteManager initialized for interface {self.interface}")

    def is_wg(self):
        """Return True if `sudo wg` lists our interface, False otherwise or if the command fails."""
        name = self.interface.get('name')
        if not name:
            return False
        try:
            cmd_output = subprocess.check_output(["sudo", "wg"]).decode()
        except (subprocess.CalledProcessError, OSError, UnicodeDecodeError):
            # wg missing, sudo refused, or unreadable output: treat as "not WireGuard".
            return False
        # Compare the whole line so that "wg0" does not also match "wg01".
        wanted = f"interface: {name}"
        return any(line.strip() == wanted for line in cmd_output.splitlines())

    # Diffs the two lists and resolves system state to match routes given by API
    def update_routes(self, api_routes, current_routes):
        """Make the system routes match the routes given by the API.

        Args:
            api_routes: Wanted routes as (network, gateway) tuples.
            current_routes: Existing routes; only the first two fields (network, gateway) of each
                entry are compared, any further fields are passed through to delete_route().

        Stale routes are deleted first, then missing ones are added.
        """
        # Build both lookup sets once instead of rebuilding a list for every comparison.
        wanted = set(api_routes)
        present = {tuple(current_route[:2]) for current_route in current_routes}
        routes_to_add = [route for route in api_routes if route not in present]
        routes_to_remove = [current_route for current_route in current_routes if tuple(current_route[:2]) not in wanted]
        for route in routes_to_remove:
            self.delete_route(route)
        for route in routes_to_add:
            self.add_route(route)

    def correct_cidr(self, destination):
        """Expand an abbreviated route destination into full "a.b.c.d/len" notation.

        Trailing zero/empty octets are dropped and the prefix length is derived from the number
        of remaining octets (8 bits each) unless the destination carries its own "/len":
        "default" -> "0.0.0.0/0", "10" -> "10.0.0.0/8", "192.168.1" -> "192.168.1.0/24",
        "0/1" -> "0.0.0.0/1".

        Raises:
            ValueError: If the result is not a valid network.
        """
        if destination == 'default':
            return '0.0.0.0/0'
        parts = destination.split('/')
        octets = parts[0].split('.')
        # Guard against running off the front of the list for all-zero destinations like "0/1".
        while octets and octets[-1] in ('', '0'):
            octets.pop()
        # Fill in missing octets with '0'; each one removes 8 bits from the implied prefix length.
        missing = max(4 - len(octets), 0)
        octets += ['0'] * missing
        cidr_suffix = 32 - 8 * missing
        if len(parts) == 2:
            # If CIDR suffix is provided, use it instead of calculated one
            cidr_suffix = parts[1]
        corrected_ip = '.'.join(octets) + '/' + str(cidr_suffix)
        return str(ipaddress.ip_network(corrected_ip, strict=False))

    # Parses the table printed by `netsh interface ipv4|ipv6 show route` into (network, gateway) tuples for our interface.
    def parse_netsh_output(self, route_output, ip_version):
        """Extract this interface's routes from `netsh interface ipvX show route` output.

        Args:
            route_output: Text printed by netsh.
            ip_version: 4 or 6; any other value yields no routes.

        Returns:
            A list of (network, gateway) tuples for rows whose interface index equals
            self.interface['indx']. gateway is None when the last column is not an IP address
            (i.e. the route is on-link).
        """
        families = {
            4: (ipaddress.IPv4Network, ipaddress.IPv4Address),
            6: (ipaddress.IPv6Network, ipaddress.IPv6Address),
        }
        routes = []
        for line in route_output.splitlines()[2:]:  # Skip headers
            parts = line.split()
            # Common format: Publish, Type, Met, Prefix, Idx, Gateway/Interface Name
            # All six columns are read below, so shorter rows cannot be route entries.
            if len(parts) < 6:
                continue
            prefix, interface_idx, gateway_or_interface = parts[3], parts[4], parts[5]
            if interface_idx != str(self.interface['indx']):
                continue
            if ip_version not in families:
                continue
            network_cls, address_cls = families[ip_version]
            try:
                network = network_cls(prefix, strict=False)
            except ValueError:
                continue  # Skip invalid entries
            # If a gateway is defined, it shows up in this column. Otherwise windows reports the
            # interface name, which basically means on-link.
            try:
                gateway = address_cls(gateway_or_interface)
            except ipaddress.AddressValueError:
                gateway = None
            routes.append((network, gateway))
        return routes

    # Returns system routes on our interface, except the link-local route and multicast route.
    # Used for differential comparison to result of get_connected_subnets.
    # Returns a list of tuples containing an ipaddress object and gateway.
    def get_current_routes(self):
        routes = []
        multicast_network = ipaddress.IPv4Network("224.0.0.0/4", strict=False)
        global_broadcast = ipaddress.IPv4Network("255.255.255.255/32", strict=False)
        if self.os_name == "linux":
            try:
                cmd = ["ip", "route", "show", "dev", self.interface['name']]
                cmd_output = subprocess.check_output(cmd).decode()
                for line in cmd_output.splitlines():
                    parts = line.split()
                    cidr = parts[0]
                    gateway = None
                    if 'via' in parts:
                        gateway_index = parts.index('via') + 1
                        gateway = parts[gateway_index]
                        gateway = ipaddress.IPv4Address(gateway)
                    cidr = ipaddress.ip_network(cidr, strict=False)
                    if cidr != self.interface['cidr'] and cidr != multicast_network and not cidr.subnet_of(multicast_network) and cidr != global_broadcast:
                        routes.append((cidr, gateway))
            except subprocess.CalledProcessError as e:
                logging.error(f"Error fetching routes in Linux: {e}")
                return None

        # netsh is way easier to parse than `route print` or `netstat -nr` IMO
        # ...Even considering the large function above making this look nice
        elif self.os_name == "windows":
            try:
                # Fetching IPv4 routing table
                cmd = ["netsh", "interface", "ipv4", "show", "route"]
                cmd_output = subprocess.check_output(cmd).decode()
                ipv4_routes = self.parse_netsh_output(cmd_output, 4)
                # Fetching IPv6 routing table
                cmd = ["netsh", "interface", "ipv6", "show", "route"]
                cmd_output = subprocess.check_output(cmd).decode()
                ipv6_routes = self.parse_netsh_output(cmd_output, 6)
                # Filtering for multicast or link-local network
                for cidr, gateway in ipv4_routes:
                    if cidr != self.interface['cidr'] and cidr != multicast_network and not cidr.subnet_of(multicast_network) and cidr != global_broadcast:
                        routes.append((cidr, gateway))
            except subprocess.CalledProcessError as e:
                self.logger.error(f"Error fetching routes in Windows: {e}")
                return None
        elif self.os_name == "darwin":
            try:
                cmd = ["netstat", "-nr", "-f", "inet"]
                cmd_output = subprocess.check_output(cmd).decode()
                headers_found = False
                for line in cmd_output.splitlines():
                    parts = line.split()
                    if "Destination" in line:
                        # Found header row - populate index vars for fetching route parts
                        cidr_index = parts.index('Destination')
                        gateway_index = parts.index('Gateway')
                        ifname_index = parts.index('Netif')
                        headers_found = True
                        continue
                    if not headers_found:
                        # Have not seen header row yet
                        continue
                    if parts[ifname_index] != self.interface['name']:
                        # Route is not on our interface
                        continue
                    dest = parts[cidr_index]
                    gateway = parts[gateway_index]
                    try:
                        gateway = ipaddress.IPv4Address(gateway)
                    except:
                        gateway = None
                    orig_str = dest
                    cidr = self.correct_cidr(dest)
                    cidr = ipaddress.ip_network(cidr, strict=False)
                    # MacOS adds a route to local addr/32 via self for some reason - leave it be
                    if cidr.prefixlen == 32 and ipaddress.IPv4Address(self.interface['addr']) in cidr:
                        continue
                    if cidr != self.interface['cidr'] and cidr != multicast_network and not cidr.subnet_of(multicast_network) and cidr != global_broadcast:
                        routes.append((cidr, gateway, orig_str))
            except subprocess.CalledProcessError as e:
                self.logger.error(f"Error fetching routes in MacOS: {e}")
                return None

        return routes

    # Takes in one of those tuples from a route list
    def add_route(self, route):
        cidr, gateway = route
        self.logger.info(f"Adding route: {cidr}")
        if self.os_name == "linux":
            cmd = ["sudo", "ip", "route", "add", str(cidr), "dev", self.interface['name']]
            if gateway:
                cmd += ["via", str(gateway)]
        elif self.os_name == "windows":
            cmd = ["route", "add", str(cidr), "mask", str(cidr.netmask), str(gateway), "METRIC", "1", "if", str(self.interface['indx'])]
        elif self.os_name == "darwin":
            cmd = ["sudo", "route", "add", str(cidr)]
            if gateway:
                cmd += [str(gateway)]
            cmd += ["-ifscope", self.interface['name']]
        else:
            self.logger.error(f"Unsupported OS: {self.os_name}")
            return
        try:
            cmd_output = subprocess.check_output(cmd).decode()
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Error executing command: {e}")

    def delete_route(self, route):
        cidr, gateway = route[:2]
        self.logger.info(f"Deleting route: {cidr}")
        if self.os_name == "linux":
            cmd = ["sudo", "ip", "route", "del", str(cidr), "dev", self.interface['name']]
        elif self.os_name == "windows":
            cmd = ["route", "delete", str(cidr)]
        elif self.os_name == "darwin":
            cmd = ["sudo", "route", "delete"]
            if '/' in route[2] or '.' in route[2]:
                cmd += [ route[2] ]
            else:
                cmd += [ str(cidr) ]
            if gateway:
                cmd += [str(gateway)]
            cmd += [ "-ifscope", self.interface['name'] ]
        else:
            self.logger.error(f"Unsupported OS: {self.os_name}")
            return
        try:
            cmd_output = subprocess.check_output(cmd).decode()
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Error executing command: {e}")


class RAThreadManager:
    """Keeps exactly one RoutingAgent running for every active matching interface.

    Args:
        logger: Logger shared with every agent this manager creates.
    """

    def __init__(self, logger):
        self.agents = {}
        self.logger = logger

    def get_active_interfaces(self):
        """Return the names of 'up' interfaces matching NIC_PREFIXES that have a usable IPv4 address."""
        ifaces = psutil.net_if_addrs()
        # Take one stats snapshot instead of querying the OS again for every interface.
        stats = psutil.net_if_stats()
        prefixes = tuple(NIC_PREFIXES)
        active_interface_names = []
        for interface, addrs in ifaces.items():
            nic = stats.get(interface)
            # An interface can be missing from the stats snapshot if it appeared or vanished
            # between the two calls; treat it as not active.
            if nic is None or not nic.isup or not interface.startswith(prefixes):
                continue
            # Check that it has a non-loopback IPv4 address we can use.
            # IPv6-only interfaces are not accepted (yet).
            if any(addr.family == socket.AF_INET and not addr.address.startswith("127") for addr in addrs):
                active_interface_names.append(interface)
        return active_interface_names

    def _start_agent(self, interface_name):
        """Create and start an agent for `interface_name`, registering it only once it is running."""
        agent = RoutingAgent(interface_name, logger=self.logger)
        agent.start()
        # Register after start() so a failed start never leaves a thread-less agent behind.
        self.agents[interface_name] = agent

    def update_state(self):
        """Make sure there is one running agent per active interface and none for inactive ones."""
        active_interfaces = self.get_active_interfaces()
        current_interfaces = set(self.agents.keys())
        # Start agents for new interfaces
        for interface_name in active_interfaces:
            if interface_name not in current_interfaces:
                self.logger.info(f"Starting listener for new interface: {interface_name}")
                self._start_agent(interface_name)
            else:
                # Check if existing thread is still alive, stop+delete+recreate if not
                self.logger.info(f"Check existing thread for interface: {interface_name}")
                if not self.agents[interface_name].thread.is_alive():
                    self.logger.warning(f"Thread for {interface_name} died, restarting...")
                    self.agents[interface_name].stop()
                    del self.agents[interface_name]
                    self._start_agent(interface_name)
        # Stop and remove agents for interfaces that are no longer active
        still_active = set(active_interfaces)
        for interface_name in current_interfaces - still_active:
            self.logger.info(f"Stopping listener for removed interface: {interface_name}")
            self.agents[interface_name].stop()
            del self.agents[interface_name]

    def reset(self):
        """Call reset() on every managed agent."""
        for agent in self.agents.values():
            agent.reset()