diff --git a/README.md b/README.md index 4ca5156..3d598dd 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,14 @@ We have implemented PIM-DM specification ([RFC3973](https://tools.ietf.org/html/ This repository stores the implementation of this protocol. The implementation is written in Python language and is destined to Linux systems. +Additionally, IGMPv2 and MLDv1 are implemented alongside with PIM-DM to detect interest of hosts. + # Requirements - Linux machine - - Python3 (we have written all code to be compatible with at least Python v3.2) + - Unicast routing protocol + - Python3 (we have written all code to be compatible with at least Python v3.3) - pip (to install all dependencies) - tcpdump @@ -41,6 +44,8 @@ In order to start the protocol you first need to explicitly start it. This will sudo pim-dm -start ``` +IPv4 and IPv6 multicast is supported. By default all commands will be executed on IPv4 daemon. To execute a command on the IPv6 daemon use `-6`. + #### Add interface @@ -49,21 +54,27 @@ After starting the protocol process you can enable the protocol in specific inte - To enable PIM-DM without State-Refresh, in a given interface, you need to run the following command: ``` - sudo pim-dm -ai INTERFACE_NAME + sudo pim-dm -ai INTERFACE_NAME [-4 | -6] ``` - To enable PIM-DM with State-Refresh, in a given interface, you need to run the following command: ``` - sudo pim-dm -aisf INTERFACE_NAME + sudo pim-dm -aisf INTERFACE_NAME [-4 | -6] ``` -- To enable IGMP in a given interface, you need to run the following command: +- To enable IGMP/MLD in a given interface, you need to run the following command: + - IGMP: ``` sudo pim-dm -aiigmp INTERFACE_NAME ``` + - MLD: + ``` + sudo pim-dm -aimld INTERFACE_NAME + ``` + #### Remove interface To remove a previously added interface, you need run the following commands: @@ -71,15 +82,20 @@ To remove a previously added interface, you need run the following commands: - To remove a previously added PIM-DM interface: ``` - sudo pim-dm -ri INTERFACE_NAME + sudo pim-dm -ri INTERFACE_NAME [-4 | -6] ``` -- To remove a previously added IGMP interface: - +- To remove a previously added IGMP/MLD interface: + - IGMP: ``` sudo pim-dm -riigmp INTERFACE_NAME ``` + - MLD: + ``` + sudo pim-dm -rimld INTERFACE_NAME + ``` + #### Stop protocol process @@ -96,31 +112,31 @@ We have built some list commands that can be used to check the "internals" of th - #### List interfaces: - Show all router interfaces and which ones have PIM-DM and IGMP enabled. For IGMP enabled interfaces check the IGMP Querier state. + Show all router interfaces and which ones have PIM-DM and IGMP/MLD enabled. For IGMP/MLD enabled interfaces you can check the Querier state. ``` - sudo pim-dm -li + sudo pim-dm -li [-4 | -6] ``` - #### List neighbors Verify neighbors that have established a neighborhood relationship. ``` - sudo pim-dm -ln + sudo pim-dm -ln [-4 | -6] ``` - #### List state List all state machines and corresponding state of all trees that are being monitored. Also list IGMP state for each group being monitored. ``` - sudo pim-dm -ls + sudo pim-dm -ls [-4 | -6] ``` - #### Multicast Routing Table - List Linux Multicast Routing Table (equivalent to `ip mroute -show`) + List Linux Multicast Routing Table (equivalent to `ip mroute show`) ``` - sudo pim-dm -mr + sudo pim-dm -mr [-4 | -6] ``` @@ -131,15 +147,10 @@ In order to determine which commands and corresponding arguments are available y pim-dm -h ``` - or - - ``` - pim-dm --help - ``` ## Change settings -Files tree/globals.py and igmp/igmp_globals.py store all timer values and some configurations regarding IGMP and the PIM-DM. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface. +Files tree/globals.py, igmp/igmp_globals.py and mld/mld_globals.py store all timer values and some configurations regarding PIM-DM, IGMP and MLD. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface. ## Tests @@ -151,4 +162,4 @@ We have performed tests to our PIM-DM implementation. You can check on the corre - [Test_PIM_Assert](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Assert) - Topology used to test the election of the AssertWinner. - [Test_PIM_Join_Prune_Graft](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_Join_Prune_Graft) - Topology used to test the Pruning and Grafting of the multicast distribution tree. - [Test_PIM_StateRefresh](https://github.com/pedrofran12/pim_dm/tree/Test_PIM_StateRefresh) - Topology used to test PIM-DM State Refresh. -- [Test_IGMP](https://github.com/pedrofran12/hpim_dm/tree/Test_IGMP) - Topology used to test our IGMPv2 implementation. +- [Test_IGMP](https://github.com/pedrofran12/pim_dm/tree/Test_IGMP) - Topology used to test our IGMPv2 implementation. diff --git a/pimdm/Interface.py b/pimdm/Interface.py index c11c5ae..b66c9b3 100644 --- a/pimdm/Interface.py +++ b/pimdm/Interface.py @@ -18,8 +18,11 @@ def __init__(self, interface_name, recv_socket, send_socket, vif_index): self._recv_socket = recv_socket self.interface_enabled = False - def _enable(self): + """ + Enable this interface + This will start a thread to be executed in the background to be used in the reception of control packets + """ self.interface_enabled = True # run receive method in background receive_thread = threading.Thread(target=self.receive) @@ -27,24 +30,39 @@ def _enable(self): receive_thread.start() def receive(self): + """ + Method that will be executed in the background for the reception of control packets + """ while self.interface_enabled: try: - (raw_bytes, _) = self._recv_socket.recvfrom(256 * 1024) + (raw_bytes, ancdata, _, src_addr) = self._recv_socket.recvmsg(256 * 1024, 500) if raw_bytes: - self._receive(raw_bytes) + self._receive(raw_bytes, ancdata, src_addr) except Exception: traceback.print_exc() continue @abstractmethod - def _receive(self, raw_bytes): + def _receive(self, raw_bytes, ancdata, src_addr): + """ + Subclass method to be implemented + This method will be invoked whenever a new control packet is received + """ raise NotImplementedError def send(self, data: bytes, group_ip: str): + """ + Send a control packet through this interface + Explicitly destined to group_ip (can be unicast or multicast IP) + """ if self.interface_enabled and data: self._send_socket.sendto(data, (group_ip, 0)) def remove(self): + """ + This interface is no longer active.... + Clear all state regarding it + """ self.interface_enabled = False try: self._recv_socket.shutdown(socket.SHUT_RDWR) @@ -54,8 +72,14 @@ def remove(self): self._send_socket.close() def is_enabled(self): + """ + Verify if this interface is enabled + """ return self.interface_enabled @abstractmethod def get_ip(self): + """ + Get IP of this interface + """ raise NotImplementedError diff --git a/pimdm/InterfaceIGMP.py b/pimdm/InterfaceIGMP.py index e34a964..bf9b6fc 100644 --- a/pimdm/InterfaceIGMP.py +++ b/pimdm/InterfaceIGMP.py @@ -4,7 +4,7 @@ from ctypes import create_string_buffer, addressof import netifaces from pimdm.Interface import Interface -from pimdm.Packet.ReceivedPacket import ReceivedPacket +from pimdm.packet.ReceivedPacket import ReceivedPacket from pimdm.igmp.igmp_globals import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query if not hasattr(socket, 'SO_BINDTODEVICE'): socket.SO_BINDTODEVICE = 25 @@ -48,18 +48,20 @@ def __init__(self, interface_name: str, vif_index: int): self.interface_state = RouterState(self) super()._enable() - def get_ip(self): return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr'] @property def ip_interface(self): + """ + Get IP of this interface + """ return self.get_ip() def send(self, data: bytes, address: str="224.0.0.1"): super().send(data, address) - def _receive(self, raw_bytes): + def _receive(self, raw_bytes, ancdata, src_addr): if raw_bytes: raw_bytes = raw_bytes[14:] packet = ReceivedPacket(raw_bytes, self) @@ -91,7 +93,8 @@ def receive_leave_group(self, packet): def receive_membership_query(self, packet): ip_dst = packet.ip_header.ip_dst igmp_group = packet.payload.group_address - if ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"): + if (IPv4Address(igmp_group).is_multicast and ip_dst == igmp_group) or \ + (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"): self.interface_state.receive_query(packet) def receive_unknown_type(self, packet): diff --git a/pimdm/InterfaceMLD.py b/pimdm/InterfaceMLD.py new file mode 100644 index 0000000..37fadaf --- /dev/null +++ b/pimdm/InterfaceMLD.py @@ -0,0 +1,199 @@ +import socket +import struct +import netifaces +import ipaddress +from socket import if_nametoindex +from ipaddress import IPv6Address +from .Interface import Interface +from .packet.ReceivedPacket import ReceivedPacket_v6 +from .mld.mld_globals import MULTICAST_LISTENER_QUERY_TYPE, MULTICAST_LISTENER_DONE_TYPE, MULTICAST_LISTENER_REPORT_TYPE +from ctypes import create_string_buffer, addressof + +ETH_P_IPV6 = 0x86DD # IPv6 over bluebook +SO_ATTACH_FILTER = 26 +ICMP6_FILTER = 1 +IPV6_ROUTER_ALERT = 22 + + +def ICMP6_FILTER_SETBLOCKALL(): + return struct.pack("I"*8, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF) + + +def ICMP6_FILTER_SETPASS(type, filterp): + return filterp[:type >> 5] + (bytes([(filterp[type >> 5] & ~(1 << ((type) & 31)))])) + filterp[(type >> 5) + 1:] + + +class InterfaceMLD(Interface): + IPv6_LINK_SCOPE_ALL_NODES = IPv6Address("ff02::1") + IPv6_LINK_SCOPE_ALL_ROUTERS = IPv6Address("ff02::2") + IPv6_ALL_ZEROS = IPv6Address("::") + + FILTER_MLD = [ + struct.pack('HBBI', 0x28, 0, 0, 0x0000000c), + struct.pack('HBBI', 0x15, 0, 9, 0x000086dd), + struct.pack('HBBI', 0x30, 0, 0, 0x00000014), + struct.pack('HBBI', 0x15, 0, 7, 0x00000000), + struct.pack('HBBI', 0x30, 0, 0, 0x00000036), + struct.pack('HBBI', 0x15, 0, 5, 0x0000003a), + struct.pack('HBBI', 0x30, 0, 0, 0x0000003e), + struct.pack('HBBI', 0x15, 2, 0, 0x00000082), + struct.pack('HBBI', 0x15, 1, 0, 0x00000083), + struct.pack('HBBI', 0x15, 0, 1, 0x00000084), + struct.pack('HBBI', 0x6, 0, 0, 0x00040000), + struct.pack('HBBI', 0x6, 0, 0, 0x00000000), + ] + + def __init__(self, interface_name: str, vif_index: int): + # SEND SOCKET + s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_ICMPV6) + + # set socket output interface + if_index = if_nametoindex(interface_name) + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack('@I', if_index)) + + """ + # set ICMP6 filter to only receive MLD packets + icmp6_filter = ICMP6_FILTER_SETBLOCKALL() + icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_QUERY_TYPE, icmp6_filter) + icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_REPORT_TYPE, icmp6_filter) + icmp6_filter = ICMP6_FILTER_SETPASS(MULTICAST_LISTENER_DONE_TYPE, icmp6_filter) + s.setsockopt(socket.IPPROTO_ICMPV6, ICMP6_FILTER, icmp6_filter) + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_RECVPKTINFO, True) + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, False) + s.setsockopt(socket.IPPROTO_IPV6, self.IPV6_ROUTER_ALERT, 0) + rcv_s = s + """ + + ip_interface = "::" + for if_addr in netifaces.ifaddresses(interface_name)[netifaces.AF_INET6]: + ip_interface = if_addr["addr"] + if ipaddress.IPv6Address(ip_interface.split("%")[0]).is_link_local: + # bind to interface + s.bind(socket.getaddrinfo(ip_interface, None, 0, socket.SOCK_RAW, 0, socket.AI_PASSIVE)[0][4]) + ip_interface = ip_interface.split("%")[0] + break + self.ip_interface = ip_interface + + # RECEIVE SOCKET + rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(ETH_P_IPV6)) + + # receive only MLD packets by setting a BPF filter + bpf_filter = b''.join(InterfaceMLD.FILTER_MLD) + b = create_string_buffer(bpf_filter) + mem_addr_of_filters = addressof(b) + fprog = struct.pack('HL', len(InterfaceMLD.FILTER_MLD), mem_addr_of_filters) + rcv_s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog) + + # bind to interface + rcv_s.bind((interface_name, ETH_P_IPV6)) + + super().__init__(interface_name=interface_name, recv_socket=rcv_s, send_socket=s, vif_index=vif_index) + self.interface_enabled = True + from .mld.RouterState import RouterState + self.interface_state = RouterState(self) + super()._enable() + + @staticmethod + def _get_address_family(): + return socket.AF_INET6 + + def get_ip(self): + return self.ip_interface + + def send(self, data: bytes, address: str = "FF02::1"): + # send router alert option + cmsg_level = socket.IPPROTO_IPV6 + cmsg_type = socket.IPV6_HOPOPTS + cmsg_data = b'\x3a\x00\x05\x02\x00\x00\x01\x00' + self._send_socket.sendmsg([data], [(cmsg_level, cmsg_type, cmsg_data)], 0, (address, 0)) + + """ + def receive(self): + while self.interface_enabled: + try: + (raw_bytes, ancdata, _, src_addr) = self._recv_socket.recvmsg(256 * 1024, 500) + if raw_bytes: + self._receive(raw_bytes, ancdata, src_addr) + except Exception: + import traceback + traceback.print_exc() + continue + """ + + def _receive(self, raw_bytes, ancdata, src_addr): + if raw_bytes: + raw_bytes = raw_bytes[14:] + src_addr = (socket.inet_ntop(socket.AF_INET6, raw_bytes[8:24]),) + print("MLD IP_SRC bf= ", src_addr) + dst_addr = raw_bytes[24:40] + (next_header,) = struct.unpack("B", raw_bytes[6:7]) + print("NEXT HEADER:", next_header) + payload_starts_at_len = 40 + if next_header == 0: + # Hop by Hop options + (next_header,) = struct.unpack("B", raw_bytes[40:41]) + if next_header != 58: + return + (hdr_ext_len,) = struct.unpack("B", raw_bytes[payload_starts_at_len +1:payload_starts_at_len + 2]) + if hdr_ext_len > 0: + payload_starts_at_len = payload_starts_at_len + 1 + hdr_ext_len*8 + else: + payload_starts_at_len = payload_starts_at_len + 8 + + raw_bytes = raw_bytes[payload_starts_at_len:] + ancdata = [(socket.IPPROTO_IPV6, socket.IPV6_PKTINFO, dst_addr)] + print("RECEIVE MLD") + print("ANCDATA: ", ancdata, "; SRC_ADDR: ", src_addr) + packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 58, self) + ip_src = packet.ip_header.ip_src + print("MLD IP_SRC = ", ip_src) + if not (ip_src == "::" or IPv6Address(ip_src).is_multicast): + self.PKT_FUNCTIONS.get(packet.payload.get_mld_type(), InterfaceMLD.receive_unknown_type)(self, packet) + """ + def _receive(self, raw_bytes, ancdata, src_addr): + if raw_bytes: + packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 58, self) + self.PKT_FUNCTIONS[packet.payload.get_mld_type(), InterfaceMLD.receive_unknown_type](self, packet) + """ + ########################################### + # Recv packets + ########################################### + def receive_multicast_listener_report(self, packet): + print("RECEIVE MULTICAST LISTENER REPORT") + ip_dst = packet.ip_header.ip_dst + mld_group = packet.payload.group_address + ipv6_group = IPv6Address(mld_group) + ipv6_dst = IPv6Address(ip_dst) + if ipv6_dst == ipv6_group and ipv6_group.is_multicast: + self.interface_state.receive_report(packet) + + def receive_multicast_listener_done(self, packet): + print("RECEIVE MULTICAST LISTENER DONE") + ip_dst = packet.ip_header.ip_dst + mld_group = packet.payload.group_address + if IPv6Address(ip_dst) == self.IPv6_LINK_SCOPE_ALL_ROUTERS and IPv6Address(mld_group).is_multicast: + self.interface_state.receive_done(packet) + + def receive_multicast_listener_query(self, packet): + print("RECEIVE MULTICAST LISTENER QUERY") + ip_dst = packet.ip_header.ip_dst + mld_group = packet.payload.group_address + ipv6_group = IPv6Address(mld_group) + ipv6_dst = IPv6Address(ip_dst) + if (ipv6_group.is_multicast and ipv6_dst == ipv6_group) or\ + (ipv6_dst == self.IPv6_LINK_SCOPE_ALL_NODES and ipv6_group == self.IPv6_ALL_ZEROS): + self.interface_state.receive_query(packet) + + def receive_unknown_type(self, packet): + raise Exception("UNKNOWN MLD TYPE: " + str(packet.payload.get_mld_type())) + + PKT_FUNCTIONS = { + MULTICAST_LISTENER_REPORT_TYPE: receive_multicast_listener_report, + MULTICAST_LISTENER_DONE_TYPE: receive_multicast_listener_done, + MULTICAST_LISTENER_QUERY_TYPE: receive_multicast_listener_query, + } + + ################## + def remove(self): + super().remove() + self.interface_state.remove() diff --git a/pimdm/InterfacePIM.py b/pimdm/InterfacePIM.py index 50d03dd..b54f556 100644 --- a/pimdm/InterfacePIM.py +++ b/pimdm/InterfacePIM.py @@ -1,19 +1,19 @@ +import socket import random -from pimdm.Interface import Interface -from pimdm.Packet.ReceivedPacket import ReceivedPacket -from pimdm import Main +import logging +import netifaces import traceback -from pimdm.RWLock.RWLock import RWLockWrite -from pimdm.Packet.PacketPimHelloOptions import * -from pimdm.Packet.PacketPimHello import PacketPimHello -from pimdm.Packet.PacketPimHeader import PacketPimHeader -from pimdm.Packet.Packet import Packet -from pimdm.utils import HELLO_HOLD_TIME_TIMEOUT from threading import Timer -from pimdm.tree.globals import REFRESH_INTERVAL -import socket -import netifaces -import logging + +from pimdm.Interface import Interface +from pimdm.packet.ReceivedPacket import ReceivedPacket +from pimdm import Main +from pimdm.rwlock.RWLock import RWLockWrite +from pimdm.packet.PacketPimHelloOptions import * +from pimdm.packet.PacketPimHello import PacketPimHello +from pimdm.packet.PacketPimHeader import PacketPimHeader +from pimdm.packet.Packet import Packet +from pimdm.tree.globals import HELLO_HOLD_TIME_TIMEOUT, REFRESH_INTERVAL class InterfacePim(Interface): @@ -83,18 +83,37 @@ def __init__(self, interface_name: str, vif_index:int, state_refresh_capable:boo self.force_send_hello() def get_ip(self): + """ + Get IP of this interface + """ return self.ip_interface - def _receive(self, raw_bytes): + @staticmethod + def get_kernel(): + """ + Get Kernel object + """ + return Main.kernel + + def _receive(self, raw_bytes, ancdata, src_addr): + """ + Interface received a new control packet + """ if raw_bytes: packet = ReceivedPacket(raw_bytes, self) - self.PKT_FUNCTIONS[packet.payload.get_pim_type()](self, packet) + self.PKT_FUNCTIONS.get(packet.payload.get_pim_type(), InterfacePim.receive_unknown)(self, packet) def send(self, data: bytes, group_ip: str=MCAST_GRP): + """ + Send a new packet destined to group_ip IP + """ super().send(data=data, group_ip=group_ip) #Random interval for initial Hello message on bootup or triggered Hello message to a rebooting neighbor def force_send_hello(self): + """ + Force the transmission of a new Hello message + """ if self.hello_timer is not None: self.hello_timer.cancel() @@ -103,6 +122,10 @@ def force_send_hello(self): self.hello_timer.start() def send_hello(self): + """ + Send a new Hello message + Include in it the HelloHoldTime and GenerationID + """ self.interface_logger.debug('Send Hello message') self.hello_timer.cancel() @@ -125,6 +148,10 @@ def send_hello(self): self.hello_timer.start() def remove(self): + """ + Remove this interface + Clear all state + """ self.hello_timer.cancel() self.hello_timer = None @@ -136,17 +163,20 @@ def remove(self): packet = Packet(payload=ph) self.send(packet.bytes()) - Main.kernel.interface_change_number_of_neighbors() + self.get_kernel().interface_change_number_of_neighbors() super().remove() def check_number_of_neighbors(self): has_neighbors = len(self.neighbors) > 0 if has_neighbors != self._had_neighbors: self._had_neighbors = has_neighbors - Main.kernel.interface_change_number_of_neighbors() + self.get_kernel().interface_change_number_of_neighbors() def new_or_reset_neighbor(self, neighbor_ip): - Main.kernel.new_or_reset_neighbor(self.vif_index, neighbor_ip) + """ + React to new neighbor or restart of known neighbor + """ + self.get_kernel().new_or_reset_neighbor(self.vif_index, neighbor_ip) ''' def add_neighbor(self, ip, random_number, hello_hold_time): @@ -160,27 +190,44 @@ def add_neighbor(self, ip, random_number, hello_hold_time): ''' def get_neighbors(self): + """ + Get list of known neighbors + """ with self.neighbors_lock.genRlock(): return self.neighbors.values() def get_neighbor(self, ip): + """ + Get specific neighbor by its IP + """ with self.neighbors_lock.genRlock(): return self.neighbors.get(ip) def remove_neighbor(self, ip): + """ + Remove known neighbor + """ with self.neighbors_lock.genWlock(): del self.neighbors[ip] self.interface_logger.debug("Remove neighbor: " + ip) self.check_number_of_neighbors() def set_state_refresh_capable(self, value): + """ + Change StateRefresh capability of interface + """ self._state_refresh_capable = value def is_state_refresh_enabled(self): + """ + Check if state refresh is enabled + """ return self._state_refresh_capable - # check if Interface is StateRefreshCapable def is_state_refresh_capable(self): + """ + Check StateRefresh capability of interface neighbors + """ with self.neighbors_lock.genWlock(): if len(self.neighbors) == 0: return False @@ -214,6 +261,9 @@ def change_interface(self): # Recv packets ########################################### def receive_hello(self, packet): + """ + Receive an Hello packet + """ ip = packet.ip_header.ip_src print("ip = ", ip) options = packet.payload.payload.get_options() @@ -226,7 +276,6 @@ def receive_hello(self, packet): state_refresh_capable = (21 in options) - with self.neighbors_lock.genWlock(): if ip not in self.neighbors: if hello_hold_time == 0: @@ -244,17 +293,23 @@ def receive_hello(self, packet): neighbor.receive_hello(generation_id, hello_hold_time, state_refresh_capable) def receive_assert(self, packet): + """ + Receive an Assert packet + """ pkt_assert = packet.payload.payload # type: PacketPimAssert source = pkt_assert.source_address group = pkt_assert.multicast_group_address source_group = (source, group) try: - Main.kernel.get_routing_entry(source_group).recv_assert_msg(self.vif_index, packet) + self.get_kernel().get_routing_entry(source_group).recv_assert_msg(self.vif_index, packet) except: traceback.print_exc() def receive_join_prune(self, packet): + """ + Receive Join/Prune packet + """ pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune join_prune_groups = pkt_join_prune.groups @@ -266,7 +321,7 @@ def receive_join_prune(self, packet): for source_address in joined_src_addresses: source_group = (source_address, multicast_group) try: - Main.kernel.get_routing_entry(source_group).recv_join_msg(self.vif_index, packet) + self.get_kernel().get_routing_entry(source_group).recv_join_msg(self.vif_index, packet) except: traceback.print_exc() continue @@ -274,12 +329,15 @@ def receive_join_prune(self, packet): for source_address in pruned_src_addresses: source_group = (source_address, multicast_group) try: - Main.kernel.get_routing_entry(source_group).recv_prune_msg(self.vif_index, packet) + self.get_kernel().get_routing_entry(source_group).recv_prune_msg(self.vif_index, packet) except: traceback.print_exc() continue def receive_graft(self, packet): + """ + Receive Graft packet + """ pkt_join_prune = packet.payload.payload # type: PacketPimGraft join_prune_groups = pkt_join_prune.groups @@ -290,12 +348,15 @@ def receive_graft(self, packet): for source_address in joined_src_addresses: source_group = (source_address, multicast_group) try: - Main.kernel.get_routing_entry(source_group).recv_graft_msg(self.vif_index, packet) + self.get_kernel().get_routing_entry(source_group).recv_graft_msg(self.vif_index, packet) except: traceback.print_exc() continue def receive_graft_ack(self, packet): + """ + Receive an GraftAck packet + """ pkt_join_prune = packet.payload.payload # type: PacketPimGraftAck join_prune_groups = pkt_join_prune.groups @@ -306,12 +367,15 @@ def receive_graft_ack(self, packet): for source_address in joined_src_addresses: source_group = (source_address, multicast_group) try: - Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(self.vif_index, packet) + self.get_kernel().get_routing_entry(source_group).recv_graft_ack_msg(self.vif_index, packet) except: traceback.print_exc() continue def receive_state_refresh(self, packet): + """ + Receive an StateRefresh packet + """ if not self.is_state_refresh_enabled(): return pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh @@ -320,10 +384,15 @@ def receive_state_refresh(self, packet): group = pkt_state_refresh.multicast_group_adress source_group = (source, group) try: - Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(self.vif_index, packet) + self.get_kernel().get_routing_entry(source_group).recv_state_refresh_msg(self.vif_index, packet) except: traceback.print_exc() + def receive_unknown(self, packet): + """ + Receive an unknown packet + """ + raise Exception("Unknown PIM type: " + str(packet.payload.get_pim_type())) PKT_FUNCTIONS = { 0: receive_hello, diff --git a/pimdm/InterfacePIM6.py b/pimdm/InterfacePIM6.py new file mode 100644 index 0000000..8ba612e --- /dev/null +++ b/pimdm/InterfacePIM6.py @@ -0,0 +1,91 @@ +import socket +import random +import struct +import logging +import ipaddress +import netifaces +from pimdm import Main +from socket import if_nametoindex +from pimdm.Interface import Interface +from .InterfacePIM import InterfacePim +from pimdm.rwlock.RWLock import RWLockWrite +from pimdm.packet.ReceivedPacket import ReceivedPacket_v6 + + +class InterfacePim6(InterfacePim): + MCAST_GRP = "ff02::d" + + def __init__(self, interface_name: str, vif_index:int, state_refresh_capable:bool=False): + # generation id + self.generation_id = random.getrandbits(32) + + # When PIM is enabled on an interface or when a router first starts, the Hello Timer (HT) + # MUST be set to random value between 0 and Triggered_Hello_Delay + self.hello_timer = None + + # state refresh capable + self._state_refresh_capable = state_refresh_capable + self._neighbors_state_refresh_capable = False + + # todo: lan delay enabled + self._lan_delay_enabled = False + + # todo: propagation delay + self._propagation_delay = self.PROPAGATION_DELAY + + # todo: override interval + self._override_interval = self.OVERRIDE_INTERNAL + + # pim neighbors + self._had_neighbors = False + self.neighbors = {} + self.neighbors_lock = RWLockWrite() + self.interface_logger = logging.LoggerAdapter(InterfacePim.LOGGER, {'vif': vif_index, 'interfacename': interface_name}) + + # SOCKET + s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_PIM) + + ip_interface = "" + for if_addr in netifaces.ifaddresses(interface_name)[netifaces.AF_INET6]: + ip_interface = if_addr["addr"] + if ipaddress.IPv6Address(if_addr['addr'].split("%")[0]).is_link_local: + ip_interface = if_addr['addr'].split("%")[0] + # bind to interface + s.bind(socket.getaddrinfo(if_addr['addr'], None, 0, socket.SOCK_RAW, 0, socket.AI_PASSIVE)[0][4]) + break + + self.ip_interface = ip_interface + + # allow other sockets to bind this port too + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + # explicitly join the multicast group on the interface specified + if_index = if_nametoindex(interface_name) + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, + socket.inet_pton(socket.AF_INET6, InterfacePim6.MCAST_GRP) + struct.pack('@I', if_index)) + s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8')) + + # set socket output interface + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack('@I', if_index)) + + # set socket TTL to 1 + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 1) + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_UNICAST_HOPS, 1) + + # don't receive outgoing packets + s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, 0) + Interface.__init__(self, interface_name, s, s, vif_index) + Interface._enable(self) + self.force_send_hello() + + @staticmethod + def get_kernel(): + return Main.kernel_v6 + + def send(self, data: bytes, group_ip: str=MCAST_GRP): + super().send(data=data, group_ip=group_ip) + + def _receive(self, raw_bytes, ancdata, src_addr): + if raw_bytes: + packet = ReceivedPacket_v6(raw_bytes, ancdata, src_addr, 103, self) + self.PKT_FUNCTIONS[packet.payload.get_pim_type()](self, packet) diff --git a/pimdm/Kernel.py b/pimdm/Kernel.py index 2e7014f..ccefe41 100644 --- a/pimdm/Kernel.py +++ b/pimdm/Kernel.py @@ -1,72 +1,38 @@ import socket import struct -from threading import RLock, Thread -import traceback - import ipaddress +import traceback +from socket import if_nametoindex +from threading import RLock, Thread +from abc import abstractmethod, ABCMeta -from pimdm.RWLock.RWLock import RWLockWrite +from pimdm import UnicastRouting, Main +from pimdm.rwlock.RWLock import RWLockWrite -from pimdm.InterfacePIM import InterfacePim +from pimdm.InterfaceMLD import InterfaceMLD from pimdm.InterfaceIGMP import InterfaceIGMP +from pimdm.InterfacePIM import InterfacePim +from pimdm.InterfacePIM6 import InterfacePim6 from pimdm.tree.KernelEntry import KernelEntry -from pimdm import UnicastRouting, Main - -class Kernel: - # MRT - MRT_BASE = 200 - MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */ - MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */ - MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */ - MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */ - MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */ - MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */ - MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */ - MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */ - MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */ - MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */ - #MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */ - #MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */ - #MRT_MAX = (MRT_BASE + 11) +from pimdm.tree.KernelEntryInterface import KernelEntry4Interface, KernelEntry6Interface +class Kernel(metaclass=ABCMeta): # Max Number of Virtual Interfaces MAXVIFS = 32 - # SIGNAL MSG TYPE - IGMPMSG_NOCACHE = 1 - IGMPMSG_WRONGVIF = 2 - IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM - - - # Interface flags - VIFF_TUNNEL = 0x1 # IPIP tunnel - VIFF_SRCRT = 0x2 # NI - VIFF_REGISTER = 0x4 # register vif - VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface - - def __init__(self): + def __init__(self, kernel_socket): # Kernel is running self.running = True # KEY : interface_ip, VALUE : vif_index - self.vif_dic = {} - self.vif_index_to_name_dic = {} - self.vif_name_to_index_dic = {} + self.vif_index_to_name_dic = {} # KEY : vif_index, VALUE : interface_name + self.vif_name_to_index_dic = {} # KEY : interface_name, VALUE : vif_index # KEY : source_ip, VALUE : {group_ip: KernelEntry} self.routing = {} - s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP) - - # MRT INIT - s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_INIT, 1) - - # MRT PIM - s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_PIM, 0) - s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ASSERT, 1) - - self.socket = s + self.socket = kernel_socket self.rwlock = RWLockWrite() self.interface_lock = RLock() @@ -74,9 +40,9 @@ def __init__(self): # todo useless in PIM-DM... useful in PIM-SM #self.create_virtual_interface("0.0.0.0", "pimreg", index=0, flags=Kernel.VIFF_REGISTER) - - self.pim_interface = {} # name: interface_pim - self.igmp_interface = {} # name: interface_igmp + # interfaces being monitored by this process + self.pim_interface = {} # name: interface_pim + self.membership_interface = {} # name: interface_igmp or interface_mld # logs self.interface_logger = Main.logger.getChild('KernelInterface') @@ -101,55 +67,43 @@ def __init__(self): struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */ }; ''' + @abstractmethod def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0): - if type(ip_interface) is str: - ip_interface = socket.inet_aton(ip_interface) - - struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface, - socket.inet_aton("0.0.0.0")) - with self.rwlock.genWlock(): - self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif) - self.vif_dic[socket.inet_ntoa(ip_interface)] = index - self.vif_index_to_name_dic[index] = interface_name - self.vif_name_to_index_dic[interface_name] = index - - for source_dict in list(self.routing.values()): - for kernel_entry in list(source_dict.values()): - kernel_entry.new_interface(index) - - self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index) - return index - + raise NotImplementedError def create_pim_interface(self, interface_name: str, state_refresh_capable:bool): with self.interface_lock: pim_interface = self.pim_interface.get(interface_name) - igmp_interface = self.igmp_interface.get(interface_name) - vif_already_exists = pim_interface or igmp_interface + membership_interface = self.membership_interface.get(interface_name) + vif_already_exists = pim_interface or membership_interface if pim_interface: # already exists pim_interface.set_state_refresh_capable(state_refresh_capable) return - elif igmp_interface: - index = igmp_interface.vif_index + elif membership_interface: + index = membership_interface.vif_index else: index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0] ip_interface = None if interface_name not in self.pim_interface: - pim_interface = InterfacePim(interface_name, index, state_refresh_capable) + pim_interface = self._create_pim_interface_object(interface_name, index, state_refresh_capable) self.pim_interface[interface_name] = pim_interface ip_interface = pim_interface.ip_interface if not vif_already_exists: self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index) - def create_igmp_interface(self, interface_name: str): + @abstractmethod + def _create_pim_interface_object(self, interface_name, index, state_refresh_capable): + raise NotImplementedError + + def create_membership_interface(self, interface_name: str): with self.interface_lock: pim_interface = self.pim_interface.get(interface_name) - igmp_interface = self.igmp_interface.get(interface_name) - vif_already_exists = pim_interface or igmp_interface - if igmp_interface: + membership_interface = self.membership_interface.get(interface_name) + vif_already_exists = pim_interface or membership_interface + if membership_interface: # already exists return elif pim_interface: @@ -158,47 +112,222 @@ def create_igmp_interface(self, interface_name: str): index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0] ip_interface = None - if interface_name not in self.igmp_interface: - igmp_interface = InterfaceIGMP(interface_name, index) - self.igmp_interface[interface_name] = igmp_interface + if interface_name not in self.membership_interface: + igmp_interface = self._create_membership_interface_object(interface_name, index) + self.membership_interface[interface_name] = igmp_interface ip_interface = igmp_interface.ip_interface if not vif_already_exists: self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index) + @abstractmethod + def _create_membership_interface_object(self, interface_name, index): + raise NotImplementedError - def remove_interface(self, interface_name, igmp:bool=False, pim:bool=False): + def remove_interface(self, interface_name, membership: bool = False, pim: bool = False): with self.interface_lock: - ip_interface = None pim_interface = self.pim_interface.get(interface_name) - igmp_interface = self.igmp_interface.get(interface_name) - if (igmp and not igmp_interface) or (pim and not pim_interface) or (not igmp and not pim): + membership_interface = self.membership_interface.get(interface_name) + if (membership and not membership_interface) or (pim and not pim_interface) or (not membership and not pim): return if pim: pim_interface = self.pim_interface.pop(interface_name) - ip_interface = pim_interface.ip_interface pim_interface.remove() - elif igmp: - igmp_interface = self.igmp_interface.pop(interface_name) - ip_interface = igmp_interface.ip_interface - igmp_interface.remove() + elif membership: + membership_interface = self.membership_interface.pop(interface_name) + membership_interface.remove() + + if not self.membership_interface.get(interface_name) and not self.pim_interface.get(interface_name): + self.remove_virtual_interface(interface_name) + + @abstractmethod + def remove_virtual_interface(self, interface_name): + raise NotImplementedError + + ############################################# + # Manipulate multicast routing table + ############################################# + @abstractmethod + def set_multicast_route(self, kernel_entry: KernelEntry): + raise NotImplementedError + + @abstractmethod + def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index): + raise NotImplementedError + + @abstractmethod + def remove_multicast_route(self, kernel_entry: KernelEntry): + raise NotImplementedError + + @abstractmethod + def exit(self): + raise NotImplementedError + + @abstractmethod + def handler(self): + raise NotImplementedError + + def get_routing_entry(self, source_group: tuple, create_if_not_existent=True): + ip_src = source_group[0] + ip_dst = source_group[1] + with self.rwlock.genRlock(): + if ip_src in self.routing and ip_dst in self.routing[ip_src]: + return self.routing[ip_src][ip_dst] + + with self.rwlock.genWlock(): + if ip_src in self.routing and ip_dst in self.routing[ip_src]: + return self.routing[ip_src][ip_dst] + elif create_if_not_existent: + kernel_entry = KernelEntry(ip_src, ip_dst, self._get_kernel_entry_interface()) + if ip_src not in self.routing: + self.routing[ip_src] = {} + + iif = UnicastRouting.check_rpf(ip_src) + self.set_flood_multicast_route(ip_src, ip_dst, iif) + self.routing[ip_src][ip_dst] = kernel_entry + return kernel_entry + else: + return None + + @staticmethod + @abstractmethod + def _get_kernel_entry_interface(): + pass + + # notify KernelEntries about changes at the unicast routing table + def notify_unicast_changes(self, subnet): + with self.rwlock.genWlock(): + for source_ip in list(self.routing.keys()): + source_ip_obj = ipaddress.ip_address(source_ip) + if source_ip_obj not in subnet: + continue + for group_ip in list(self.routing[source_ip].keys()): + self.routing[source_ip][group_ip].network_update() + + + # notify about changes at the interface (IP) + ''' + def notify_interface_change(self, interface_name): + with self.interface_lock: + # check if interface was already added + if interface_name not in self.vif_name_to_index_dic: + return + + print("trying to change ip") + pim_interface = self.pim_interface.get(interface_name) + if pim_interface: + old_ip = pim_interface.get_ip() + pim_interface.change_interface() + new_ip = pim_interface.get_ip() + if old_ip != new_ip: + self.vif_dic[new_ip] = self.vif_dic.pop(old_ip) + + igmp_interface = self.igmp_interface.get(interface_name) + if igmp_interface: + igmp_interface.change_interface() + ''' + + # When interface changes number of neighbors verify if olist changes and prune/forward respectively + def interface_change_number_of_neighbors(self): + with self.rwlock.genRlock(): + for groups_dict in self.routing.values(): + for entry in groups_dict.values(): + entry.change_at_number_of_neighbors() + + # When new neighbor connects try to resend last state refresh msg (if AssertWinner) + def new_or_reset_neighbor(self, vif_index, neighbor_ip): + with self.rwlock.genRlock(): + for groups_dict in self.routing.values(): + for entry in groups_dict.values(): + entry.new_or_reset_neighbor(vif_index, neighbor_ip) + + +class Kernel4(Kernel): + # MRT + MRT_BASE = 200 + MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */ + MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */ + MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */ + MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */ + MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */ + MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */ + MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */ + MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */ + MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */ + MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */ + #MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */ + #MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */ + #MRT_MAX = (MRT_BASE + 11) + + # Max Number of Virtual Interfaces + MAXVIFS = 32 + + # SIGNAL MSG TYPE + IGMPMSG_NOCACHE = 1 + IGMPMSG_WRONGVIF = 2 + IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM + + # Interface flags + VIFF_TUNNEL = 0x1 # IPIP tunnel + VIFF_SRCRT = 0x2 # NI + VIFF_REGISTER = 0x4 # register vif + VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface - if (not self.igmp_interface.get(interface_name) and not self.pim_interface.get(interface_name)): - self.remove_virtual_interface(ip_interface) + def __init__(self): + s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP) + + # MRT INIT + s.setsockopt(socket.IPPROTO_IP, self.MRT_INIT, 1) + + # MRT PIM + s.setsockopt(socket.IPPROTO_IP, self.MRT_PIM, 0) + s.setsockopt(socket.IPPROTO_IP, self.MRT_ASSERT, 1) + + super().__init__(s) + + ''' + Structure to create/remove virtual interfaces + struct vifctl { + vifi_t vifc_vifi; /* Index of VIF */ + unsigned char vifc_flags; /* VIFF_ flags */ + unsigned char vifc_threshold; /* ttl limit */ + unsigned int vifc_rate_limit; /* Rate limiter values (NI) */ + union { + struct in_addr vifc_lcl_addr; /* Local interface address */ + int vifc_lcl_ifindex; /* Local interface index */ + }; + struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */ + }; + ''' + def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0): + if type(ip_interface) is str: + ip_interface = socket.inet_aton(ip_interface) + + struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface, + socket.inet_aton("0.0.0.0")) + with self.rwlock.genWlock(): + self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_VIF, struct_mrt_add_vif) + self.vif_index_to_name_dic[index] = interface_name + self.vif_name_to_index_dic[interface_name] = index + for source_dict in list(self.routing.values()): + for kernel_entry in list(source_dict.values()): + kernel_entry.new_interface(index) - def remove_virtual_interface(self, ip_interface): + self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index) + return index + + def remove_virtual_interface(self, interface_name): #with self.interface_lock: - index = self.vif_dic[ip_interface] + index = self.vif_name_to_index_dic.pop(interface_name, None) struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0")) - self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_VIF, struct_vifctl) + self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DEL_VIF, struct_vifctl) - del self.vif_dic[ip_interface] del self.vif_name_to_index_dic[self.vif_index_to_name_dic[index]] interface_name = self.vif_index_to_name_dic.pop(index) - # alterar MFC's para colocar a 0 esta interface + # change MFC's to not forward traffic by this interface (set OIL to 0 for this interface) with self.rwlock.genWlock(): for source_dict in list(self.routing.values()): for kernel_entry in list(source_dict.values()): @@ -235,7 +364,7 @@ def set_multicast_route(self, kernel_entry: KernelEntry): #outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5 #struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0) struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, kernel_entry.inbound_interface_index, *outbound_interfaces_and_other_parameters) - self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl) + self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_MFC, struct_mfcctl) def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index): source_ip = socket.inet_aton(source_ip) @@ -250,7 +379,7 @@ def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index #outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5 #struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0) struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters) - self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl) + self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_ADD_MFC, struct_mfcctl) def remove_multicast_route(self, kernel_entry: KernelEntry): source_ip = socket.inet_aton(kernel_entry.source_ip) @@ -258,7 +387,7 @@ def remove_multicast_route(self, kernel_entry: KernelEntry): outbound_interfaces_and_other_parameters = [0] + [0]*Kernel.MAXVIFS + [0]*4 struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters) - self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_MFC, struct_mfcctl) + self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DEL_MFC, struct_mfcctl) self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip) if len(self.routing[kernel_entry.source_ip]) == 0: self.routing.pop(kernel_entry.source_ip) @@ -267,7 +396,7 @@ def exit(self): self.running = False # MRT DONE - self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DONE, 1) + self.socket.setsockopt(socket.IPPROTO_IP, self.MRT_DONE, 1) self.socket.close() @@ -304,10 +433,10 @@ def handler(self): ip_src = socket.inet_ntoa(im_src) ip_dst = socket.inet_ntoa(im_dst) - if im_msgtype == Kernel.IGMPMSG_NOCACHE: + if im_msgtype == self.IGMPMSG_NOCACHE: print("IGMP NO CACHE") self.igmpmsg_nocache_handler(ip_src, ip_dst, im_vif) - elif im_msgtype == Kernel.IGMPMSG_WRONGVIF: + elif im_msgtype == self.IGMPMSG_WRONGVIF: print("WRONG VIF HANDLER") self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif) #elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT: @@ -338,73 +467,266 @@ def igmpmsg_wholepacket_handler(self, ip_src, ip_dst): #kernel_entry.recv_data_msg(iif) ''' + @staticmethod + def _get_kernel_entry_interface(): + return KernelEntry4Interface + + def _create_pim_interface_object(self, interface_name, index, state_refresh_capable): + return InterfacePim(interface_name, index, state_refresh_capable) + + def _create_membership_interface_object(self, interface_name, index): + return InterfaceIGMP(interface_name, index) + + +class Kernel6(Kernel): + # MRT6 + MRT6_BASE = 200 + MRT6_INIT = (MRT6_BASE) # /* Activate the kernel mroute code */ + MRT6_DONE = (MRT6_BASE + 1) # /* Shutdown the kernel mroute */ + MRT6_ADD_MIF = (MRT6_BASE + 2) # /* Add a virtual interface */ + MRT6_DEL_MIF = (MRT6_BASE + 3) # /* Delete a virtual interface */ + MRT6_ADD_MFC = (MRT6_BASE + 4) # /* Add a multicast forwarding entry */ + MRT6_DEL_MFC = (MRT6_BASE + 5) # /* Delete a multicast forwarding entry */ + MRT6_VERSION = (MRT6_BASE + 6) # /* Get the kernel multicast version */ + MRT6_ASSERT = (MRT6_BASE + 7) # /* Activate PIM assert mode */ + MRT6_PIM = (MRT6_BASE + 8) # /* enable PIM code */ + MRT6_TABLE = (MRT6_BASE + 9) # /* Specify mroute table ID */ + MRT6_ADD_MFC_PROXY = (MRT6_BASE + 10) # /* Add a (*,*|G) mfc entry */ + MRT6_DEL_MFC_PROXY = (MRT6_BASE + 11) # /* Del a (*,*|G) mfc entry */ + MRT6_MAX = (MRT6_BASE + 11) + + # define SIOCGETMIFCNT_IN6 SIOCPROTOPRIVATE /* IP protocol privates */ + # define SIOCGETSGCNT_IN6 (SIOCPROTOPRIVATE+1) + # define SIOCGETRPF (SIOCPROTOPRIVATE+2) + # Max Number of Virtual Interfaces + MAXVIFS = 32 - def get_routing_entry(self, source_group: tuple, create_if_not_existent=True): - ip_src = source_group[0] - ip_dst = source_group[1] - with self.rwlock.genRlock(): - if ip_src in self.routing and ip_dst in self.routing[ip_src]: - return self.routing[ip_src][ip_dst] + # SIGNAL MSG TYPE + MRT6MSG_NOCACHE = 1 + MRT6MSG_WRONGMIF = 2 + MRT6MSG_WHOLEPKT = 3 # /* used for use level encap */ + + # Interface flags + MIFF_REGISTER = 0x1 # /* register vif */ + + def __init__(self): + s = socket.socket(socket.AF_INET6, socket.SOCK_RAW, socket.IPPROTO_ICMPV6) + + # MRT INIT + s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_INIT, 1) + + # MRT PIM + s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_PIM, 0) + s.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ASSERT, 1) + super().__init__(s) + + ''' + Structure to create/remove multicast interfaces + struct mif6ctl { + mifi_t mif6c_mifi; /* Index of MIF */ + unsigned char mif6c_flags; /* MIFF_ flags */ + unsigned char vifc_threshold; /* ttl limit */ + __u16 mif6c_pifi; /* the index of the physical IF */ + unsigned int vifc_rate_limit; /* Rate limiter values (NI) */ + }; + ''' + def create_virtual_interface(self, ip_interface, interface_name: str, index, flags=0x0): + physical_if_index = if_nametoindex(interface_name) + struct_mrt_add_vif = struct.pack("HBBHI", index, flags, 1, physical_if_index, 0) with self.rwlock.genWlock(): - if ip_src in self.routing and ip_dst in self.routing[ip_src]: - return self.routing[ip_src][ip_dst] - elif create_if_not_existent: - kernel_entry = KernelEntry(ip_src, ip_dst) - if ip_src not in self.routing: - self.routing[ip_src] = {} + self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MIF, struct_mrt_add_vif) + self.vif_index_to_name_dic[index] = interface_name + self.vif_name_to_index_dic[interface_name] = index - iif = UnicastRouting.check_rpf(ip_src) - self.set_flood_multicast_route(ip_src, ip_dst, iif) - self.routing[ip_src][ip_dst] = kernel_entry - return kernel_entry - else: - return None + for source_dict in list(self.routing.values()): + for kernel_entry in list(source_dict.values()): + kernel_entry.new_interface(index) - # notify KernelEntries about changes at the unicast routing table - def notify_unicast_changes(self, subnet): + self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index) + return index + + def remove_virtual_interface(self, interface_name): + # with self.interface_lock: + mif_index = self.vif_name_to_index_dic.pop(interface_name, None) + interface_name = self.vif_index_to_name_dic.pop(mif_index) + + physical_if_index = if_nametoindex(interface_name) + struct_vifctl = struct.pack("HBBHI", mif_index, 0, 0, physical_if_index, 0) + self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DEL_MIF, struct_vifctl) + + # alterar MFC's para colocar a 0 esta interface with self.rwlock.genWlock(): - for source_ip in list(self.routing.keys()): - source_ip_obj = ipaddress.ip_address(source_ip) - if source_ip_obj not in subnet: - continue - for group_ip in list(self.routing[source_ip].keys()): - self.routing[source_ip][group_ip].network_update() + for source_dict in list(self.routing.values()): + for kernel_entry in list(source_dict.values()): + kernel_entry.remove_interface(mif_index) + self.interface_logger.debug('Remove virtual interface: %s -> %d', interface_name, mif_index) - # notify about changes at the interface (IP) ''' - def notify_interface_change(self, interface_name): - with self.interface_lock: - # check if interface was already added - if interface_name not in self.vif_name_to_index_dic: - return + /* Cache manipulation structures for mrouted and PIMd */ + typedef __u32 if_mask; + typedef struct if_set { + if_mask ifs_bits[__KERNEL_DIV_ROUND_UP(IF_SETSIZE, NIFBITS)]; + } if_set; + + struct mf6cctl { + struct sockaddr_in6 mf6cc_origin; /* Origin of mcast */ + struct sockaddr_in6 mf6cc_mcastgrp; /* Group in question */ + mifi_t mf6cc_parent; /* Where it arrived */ + struct if_set mf6cc_ifset; /* Where it is going */ + }; + ''' + def set_multicast_route(self, kernel_entry: KernelEntry): + source_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.source_ip) + sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0) + group_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.group_ip) + sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0) - print("trying to change ip") - pim_interface = self.pim_interface.get(interface_name) - if pim_interface: - old_ip = pim_interface.get_ip() - pim_interface.change_interface() - new_ip = pim_interface.get_ip() - if old_ip != new_ip: - self.vif_dic[new_ip] = self.vif_dic.pop(old_ip) + outbound_interfaces = kernel_entry.get_outbound_interfaces_indexes() + if len(outbound_interfaces) != 8: + raise Exception - igmp_interface = self.igmp_interface.get(interface_name) - if igmp_interface: - igmp_interface.change_interface() + # outbound_interfaces_and_other_parameters = list(kernel_entry.outbound_interfaces) + [0]*4 + # outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4 + outgoing_interface_list = outbound_interfaces + + # outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5 + # struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0) + struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group, + kernel_entry.inbound_interface_index, + *outgoing_interface_list) + self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MFC, struct_mf6cctl) + + def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index): + source_ip = socket.inet_pton(socket.AF_INET6, source_ip) + sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0) + group_ip = socket.inet_pton(socket.AF_INET6, group_ip) + sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0) + + outbound_interfaces = [255] * 8 + outbound_interfaces[inbound_interface_index // 32] = 0xFFFFFFFF & ~(1 << (inbound_interface_index % 32)) + + # outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5 + # struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0) + # struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters) + struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group, + inbound_interface_index, *outbound_interfaces) + self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_ADD_MFC, struct_mf6cctl) + + def remove_multicast_route(self, kernel_entry: KernelEntry): + source_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.source_ip) + sockaddr_in6_source = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, source_ip, 0) + group_ip = socket.inet_pton(socket.AF_INET6, kernel_entry.group_ip) + sockaddr_in6_group = struct.pack("H H I 16s I", socket.AF_INET6, 0, 0, group_ip, 0) + outbound_interfaces = [0] * 8 + + # struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters) + struct_mf6cctl = struct.pack("28s 28s H " + "I" * 8, sockaddr_in6_source, sockaddr_in6_group, 0, + *outbound_interfaces) + + self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DEL_MFC, struct_mf6cctl) + self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip) + if len(self.routing[kernel_entry.source_ip]) == 0: + self.routing.pop(kernel_entry.source_ip) + + def exit(self): + self.running = False + + # MRT DONE + self.socket.setsockopt(socket.IPPROTO_IPV6, self.MRT6_DONE, 1) + self.socket.close() + + ''' + /* + * Structure used to communicate from kernel to multicast router. + * We'll overlay the structure onto an MLD header (not an IPv6 heder like igmpmsg{} + * used for IPv4 implementation). This is because this structure will be passed via an + * IPv6 raw socket, on which an application will only receiver the payload i.e the data after + * the IPv6 header and all the extension headers. (See section 3 of RFC 3542) + */ + + struct mrt6msg { + __u8 im6_mbz; /* must be zero */ + __u8 im6_msgtype; /* what type of message */ + __u16 im6_mif; /* mif rec'd on */ + __u32 im6_pad; /* padding for 64 bit arch */ + struct in6_addr im6_src, im6_dst; + }; + + /* ip6mr netlink cache report attributes */ + enum { + IP6MRA_CREPORT_UNSPEC, + IP6MRA_CREPORT_MSGTYPE, + IP6MRA_CREPORT_MIF_ID, + IP6MRA_CREPORT_SRC_ADDR, + IP6MRA_CREPORT_DST_ADDR, + IP6MRA_CREPORT_PKT, + __IP6MRA_CREPORT_MAX + }; ''' + def handler(self): + while self.running: + try: + msg = self.socket.recv(500) + if len(msg) < 40: + continue + (im6_mbz, im6_msgtype, im6_mif, _, im6_src, im6_dst) = struct.unpack("B B H I 16s 16s", msg[:40]) + # print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst))) - # When interface changes number of neighbors verify if olist changes and prune/forward respectively - def interface_change_number_of_neighbors(self): - with self.rwlock.genRlock(): - for groups_dict in self.routing.values(): - for entry in groups_dict.values(): - entry.change_at_number_of_neighbors() + if im6_mbz != 0: + continue - # When new neighbor connects try to resend last state refresh msg (if AssertWinner) - def new_or_reset_neighbor(self, vif_index, neighbor_ip): - with self.rwlock.genRlock(): - for groups_dict in self.routing.values(): - for entry in groups_dict.values(): - entry.new_or_reset_neighbor(vif_index, neighbor_ip) + print(im6_mbz) + print(im6_msgtype) + print(im6_mif) + print(socket.inet_ntop(socket.AF_INET6, im6_src)) + print(socket.inet_ntop(socket.AF_INET6, im6_dst)) + # print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst))) + + ip_src = socket.inet_ntop(socket.AF_INET6, im6_src) + ip_dst = socket.inet_ntop(socket.AF_INET6, im6_dst) + + if im6_msgtype == self.MRT6MSG_NOCACHE: + print("MRT6 NO CACHE") + self.msg_nocache_handler(ip_src, ip_dst, im6_mif) + elif im6_msgtype == self.MRT6MSG_WRONGMIF: + print("WRONG MIF HANDLER") + self.msg_wrongvif_handler(ip_src, ip_dst, im6_mif) + # elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT: + # print("IGMP_WHOLEPKT") + # self.igmpmsg_wholepacket_handler(ip_src, ip_dst) + else: + raise Exception + except Exception: + traceback.print_exc() + continue + + # receive multicast (S,G) packet and multicast routing table has no (S,G) entry + def msg_nocache_handler(self, ip_src, ip_dst, iif): + source_group_pair = (ip_src, ip_dst) + self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif) + + # receive multicast (S,G) packet in a outbound_interface + def msg_wrongvif_handler(self, ip_src, ip_dst, iif): + source_group_pair = (ip_src, ip_dst) + self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif) + + ''' useless in PIM-DM... useful in PIM-SM + def msg_wholepacket_handler(self, ip_src, ip_dst): + #kernel_entry = self.routing[(ip_src, ip_dst)] + source_group_pair = (ip_src, ip_dst) + self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg() + #kernel_entry.recv_data_msg(iif) + ''' + + @staticmethod + def _get_kernel_entry_interface(): + return KernelEntry6Interface + + def _create_pim_interface_object(self, interface_name, index, state_refresh_capable): + return InterfacePim6(interface_name, index, state_refresh_capable) + + def _create_membership_interface_object(self, interface_name, index): + return InterfaceMLD(interface_name, index) diff --git a/pimdm/Main.py b/pimdm/Main.py index 7597c19..5308185 100644 --- a/pimdm/Main.py +++ b/pimdm/Main.py @@ -1,69 +1,67 @@ import sys import time import netifaces -import logging, logging.handlers +import logging +import logging.handlers from prettytable import PrettyTable -from pimdm.TestLogger import RootFilter -from pimdm import UnicastRouting +from pimdm import UnicastRouting +from pimdm.TestLogger import RootFilter interfaces = {} # interfaces with multicast routing enabled igmp_interfaces = {} # igmp interfaces +interfaces_v6 = {} # pim v6 interfaces +mld_interfaces = {} # mld interfaces kernel = None +kernel_v6 = None unicast_routing = None logger = None -def add_pim_interface(interface_name, state_refresh_capable:bool=False): - kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable) - - -def add_igmp_interface(interface_name): - kernel.create_igmp_interface(interface_name=interface_name) - -''' -def add_interface(interface_name, pim=False, igmp=False): - #if pim is True and interface_name not in interfaces: - # interface = InterfacePim(interface_name) - # interfaces[interface_name] = interface - # interface.create_virtual_interface() - #if igmp is True and interface_name not in igmp_interfaces: - # interface = InterfaceIGMP(interface_name) - # igmp_interfaces[interface_name] = interface - kernel.create_interface(interface_name=interface_name, pim=pim, igmp=igmp) - #if pim: - # interfaces[interface_name] = kernel.pim_interface[interface_name] - #if igmp: - # igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name] -''' - -def remove_interface(interface_name, pim=False, igmp=False): - #if pim is True and ((interface_name in interfaces) or interface_name == "*"): - # if interface_name == "*": - # interface_name_list = list(interfaces.keys()) - # else: - # interface_name_list = [interface_name] - # for if_name in interface_name_list: - # interface_obj = interfaces.pop(if_name) - # interface_obj.remove() - # #interfaces[if_name].remove() - # #del interfaces[if_name] - # print("removido interface") - # print(interfaces) - - #if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"): - # if interface_name == "*": - # interface_name_list = list(igmp_interfaces.keys()) - # else: - # interface_name_list = [interface_name] - # for if_name in interface_name_list: - # igmp_interfaces[if_name].remove() - # del igmp_interfaces[if_name] - # print("removido interface") - # print(igmp_interfaces) - kernel.remove_interface(interface_name, pim=pim, igmp=igmp) - -def list_neighbors(): - interfaces_list = interfaces.values() + +def add_pim_interface(interface_name, state_refresh_capable: bool = False, ipv4=True, ipv6=False): + if interface_name == "*": + for interface_name in netifaces.interfaces(): + add_pim_interface(interface_name, ipv4, ipv6) + return + + if ipv4 and kernel is not None: + kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable) + if ipv6 and kernel_v6 is not None: + kernel_v6.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable) + + +def add_membership_interface(interface_name, ipv4=True, ipv6=False): + if interface_name == "*": + for interface_name in netifaces.interfaces(): + add_membership_interface(interface_name, ipv4, ipv6) + return + + if ipv4 and kernel is not None: + kernel.create_membership_interface(interface_name=interface_name) + if ipv6 and kernel_v6 is not None: + kernel_v6.create_membership_interface(interface_name=interface_name) + + +def remove_interface(interface_name, pim=False, membership=False, ipv4=True, ipv6=False): + if interface_name == "*": + for interface_name in netifaces.interfaces(): + remove_interface(interface_name, pim, membership, ipv4, ipv6) + return + + if ipv4 and kernel is not None: + kernel.remove_interface(interface_name, pim=pim, membership=membership) + if ipv6 and kernel_v6 is not None: + kernel_v6.remove_interface(interface_name, pim=pim, membership=membership) + + +def list_neighbors(ipv4=False, ipv6=False): + if ipv4: + interfaces_list = interfaces.values() + elif ipv6: + interfaces_list = interfaces_v6.values() + else: + return "Unknown IP family" + t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"]) check_time = time.time() for interface in interfaces_list: @@ -76,38 +74,62 @@ def list_neighbors(): print(t) return str(t) -def list_enabled_interfaces(): - global interfaces - t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'State Refresh Enabled', 'IGMP State']) +def list_enabled_interfaces(ipv4=False, ipv6=False): + if ipv4: + t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'State Refresh Enabled', 'IGMP State']) + family = netifaces.AF_INET + pim_interfaces = interfaces + membership_interfaces = igmp_interfaces + elif ipv6: + t = PrettyTable(['Interface', 'IP', 'PIM/MLD Enabled', 'State Refresh Enabled', 'MLD State']) + family = netifaces.AF_INET6 + pim_interfaces = interfaces_v6 + membership_interfaces = mld_interfaces + else: + return "Unknown IP family" + for interface in netifaces.interfaces(): try: # TODO: fix same interface with multiple ips - ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr'] - pim_enabled = interface in interfaces - igmp_enabled = interface in igmp_interfaces - enabled = str(pim_enabled) + "/" + str(igmp_enabled) + ip = netifaces.ifaddresses(interface)[family][0]['addr'] + pim_enabled = interface in pim_interfaces + membership_enabled = interface in membership_interfaces + enabled = str(pim_enabled) + "/" + str(membership_enabled) state_refresh_enabled = "-" if pim_enabled: - state_refresh_enabled = interfaces[interface].is_state_refresh_enabled() - igmp_state = "-" - if igmp_enabled: - igmp_state = igmp_interfaces[interface].interface_state.print_state() - t.add_row([interface, ip, enabled, state_refresh_enabled, igmp_state]) + state_refresh_enabled = pim_interfaces[interface].is_state_refresh_enabled() + membership_state = "-" + if membership_enabled: + membership_state = membership_interfaces[interface].interface_state.print_state() + t.add_row([interface, ip, enabled, state_refresh_enabled, membership_state]) except Exception: continue print(t) return str(t) -def list_state(): - state_text = "IGMP State:\n" + list_igmp_state() + "\n\n\n\n" + "Multicast Routing State:\n" + list_routing_state() - return state_text +def list_state(ipv4=True, ipv6=False): + state_text = "" + if ipv4: + state_text = "IGMP State:\n{}\n\n\n\nMulticast Routing State:\n{}" + elif ipv6: + state_text = "MLD State:\n{}\n\n\n\nMulticast Routing State:\n{}" + else: + return state_text + return state_text.format(list_membership_state(ipv4, ipv6), list_routing_state(ipv4, ipv6)) -def list_igmp_state(): +def list_membership_state(ipv4=True, ipv6=False): t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState']) - for (interface_name, interface_obj) in list(igmp_interfaces.items()): + if ipv4: + membership_interfaces = igmp_interfaces + elif ipv6: + membership_interfaces = mld_interfaces + else: + membership_interfaces = {} + + for (interface_name, interface_obj) in list(membership_interfaces.items()): interface_state = interface_obj.interface_state state_txt = interface_state.print_state() print(interface_state.group_state.items()) @@ -119,12 +141,22 @@ def list_igmp_state(): return str(t) -def list_routing_state(): +def list_routing_state(ipv4=False, ipv6=False): + if ipv4: + routes = kernel.routing.values() + vif_indexes = kernel.vif_index_to_name_dic.keys() + dict_index_to_name = kernel.vif_index_to_name_dic + elif ipv6: + routes = kernel_v6.routing.values() + vif_indexes = kernel_v6.vif_index_to_name_dic.keys() + dict_index_to_name = kernel_v6.vif_index_to_name_dic + else: + raise Exception("Unknown IP family") + routing_entries = [] - for a in list(kernel.routing.values()): + for a in list(routes): for b in list(a.values()): routing_entries.append(b) - vif_indexes = kernel.vif_index_to_name_dic.keys() t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', 'LocalMembership', "Is Forwarding?"]) for entry in routing_entries: @@ -134,7 +166,7 @@ def list_routing_state(): for index in vif_indexes: interface_state = entry.interface_state[index] - interface_name = kernel.vif_index_to_name_dic[index] + interface_name = dict_index_to_name[index] local_membership = type(interface_state._local_membership_state).__name__ try: assert_state = type(interface_state._assert_state).__name__ @@ -154,8 +186,11 @@ def list_routing_state(): def stop(): - remove_interface("*", pim=True, igmp=True) - kernel.exit() + remove_interface("*", pim=True, membership=True, ipv4=True, ipv6=True) + if kernel is not None: + kernel.exit() + if kernel_v6 is not None: + kernel_v6.exit() unicast_routing.stop() @@ -169,6 +204,22 @@ def test(router_name, server_logger_ip): logger.addHandler(socketHandler) +def enable_ipv6_kernel(): + """ + Function to explicitly enable IPv6 Multicast Routing stack. + This may not be enabled by default due to some old linux kernels that may not have IPv6 stack or do not have + IPv6 multicast routing support + """ + global kernel_v6 + from pimdm.Kernel import Kernel6 + kernel_v6 = Kernel6() + + global interfaces_v6 + global mld_interfaces + interfaces_v6 = kernel_v6.pim_interface + mld_interfaces = kernel_v6.membership_interface + + def main(): # logging global logger @@ -177,8 +228,8 @@ def main(): logger.addHandler(logging.StreamHandler(sys.stdout)) global kernel - from pimdm.Kernel import Kernel - kernel = Kernel() + from pimdm.Kernel import Kernel4 + kernel = Kernel4() global unicast_routing unicast_routing = UnicastRouting.UnicastRouting() @@ -186,4 +237,9 @@ def main(): global interfaces global igmp_interfaces interfaces = kernel.pim_interface - igmp_interfaces = kernel.igmp_interface + igmp_interfaces = kernel.membership_interface + + try: + enable_ipv6_kernel() + except: + pass diff --git a/pimdm/Neighbor.py b/pimdm/Neighbor.py index f2d67cf..f01c43a 100644 --- a/pimdm/Neighbor.py +++ b/pimdm/Neighbor.py @@ -1,8 +1,10 @@ -from threading import Timer import time -from pimdm.utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING -from threading import Lock, RLock import logging +from threading import Timer +from threading import Lock, RLock + +from pimdm.tree.globals import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT +from pimdm.utils import TYPE_CHECKING if TYPE_CHECKING: from pimdm.InterfacePIM import InterfacePim @@ -10,7 +12,6 @@ class Neighbor: LOGGER = logging.getLogger('pim.Interface.Neighbor') - def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int, state_refresh_capable: bool): if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT: @@ -37,7 +38,6 @@ def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, he self.tree_interface_nlt_subscribers = [] self.tree_interface_nlt_subscribers_lock = RLock() - def set_hello_hold_time(self, hello_hold_time: int): self.hello_hold_time = hello_hold_time if self.neighbor_liveness_timer is not None: @@ -85,11 +85,9 @@ def remove(self): for tree_if in self.tree_interface_nlt_subscribers: tree_if.assert_winner_nlt_expires() - def reset(self): self.contact_interface.new_or_reset_neighbor(self.ip) - def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable): self.neighbor_logger.debug('Receive Hello message with HelloHoldTime: ' + str(hello_hold_time) + '; GenerationID: ' + str(generation_id) + '; StateRefreshCapable: ' + diff --git a/pimdm/Packet/PacketIpHeader.py b/pimdm/Packet/PacketIpHeader.py deleted file mode 100644 index 65a55ea..0000000 --- a/pimdm/Packet/PacketIpHeader.py +++ /dev/null @@ -1,61 +0,0 @@ -import struct -import socket - - -''' - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -|Version| IHL |Type of Service| Total Length | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Identification |Flags| Fragment Offset | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Time to Live | Protocol | Header Checksum | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Source Address | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Destination Address | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Options | Padding | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -''' -class PacketIpHeader: - IP_HDR = "! BBH HH BBH 4s 4s" - IP_HDR_LEN = struct.calcsize(IP_HDR) - - def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst): - self.version = ver - self.hdr_length = hdr_len - self.ttl = ttl - self.proto = proto - self.ip_src = ip_src - self.ip_dst = ip_dst - - def __len__(self): - return self.hdr_length - - @staticmethod - def parse_bytes(data: bytes): - (verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \ - struct.unpack(PacketIpHeader.IP_HDR, data) - - ver = (verhlen & 0xf0) >> 4 - hlen = (verhlen & 0x0f) * 4 - - ''' - "VER": ver, - "HLEN": hlen, - "TOS": tos, - "IPLEN": iplen, - "IPID": ipid, - "FRAG": frag, - "TTL": ttl, - "PROTO": proto, - "CKSUM": cksum, - "SRC": socket.inet_ntoa(src), - "DST": socket.inet_ntoa(dst) - ''' - - src_ip = socket.inet_ntoa(src) - dst_ip = socket.inet_ntoa(dst) - return PacketIpHeader(ver, hlen, ttl, proto, src_ip, dst_ip) diff --git a/pimdm/Packet/ReceivedPacket.py b/pimdm/Packet/ReceivedPacket.py deleted file mode 100644 index 2b5f34b..0000000 --- a/pimdm/Packet/ReceivedPacket.py +++ /dev/null @@ -1,25 +0,0 @@ -from .Packet import Packet -from .PacketIpHeader import PacketIpHeader -from .PacketIGMPHeader import PacketIGMPHeader -from .PacketPimHeader import PacketPimHeader -from pimdm.utils import TYPE_CHECKING -if TYPE_CHECKING: - from pimdm.Interface import Interface - - -class ReceivedPacket(Packet): - # choose payload protocol class based on ip protocol number - payload_protocol = {2: PacketIGMPHeader, 103: PacketPimHeader} - - def __init__(self, raw_packet: bytes, interface: 'Interface'): - self.interface = interface - # Parse ao packet e preencher objeto Packet - - packet_ip_hdr = raw_packet[:PacketIpHeader.IP_HDR_LEN] - ip_header = PacketIpHeader.parse_bytes(packet_ip_hdr) - protocol_number = ip_header.proto - - packet_without_ip_hdr = raw_packet[ip_header.hdr_length:] - payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr) - - super().__init__(ip_header=ip_header, payload=payload) diff --git a/pimdm/Run.py b/pimdm/Run.py index ebea1cc..74451b5 100644 --- a/pimdm/Run.py +++ b/pimdm/Run.py @@ -1,15 +1,17 @@ #!/usr/bin/env python3 -from pimdm.Daemon.Daemon import Daemon -from pimdm import Main -import _pickle as pickle -import socket -import sys import os +import sys +import socket import argparse import traceback +import _pickle as pickle + +from pimdm import Main +from pimdm.daemon.Daemon import Daemon + +VERSION = "1.1" -VERSION = "1.0.4.2" def client_socket(data_to_send): # Create a UDS socket @@ -58,26 +60,36 @@ def run(self): print(sys.stderr, 'sending data back to the client') print(pickle.loads(data)) args = pickle.loads(data) + if 'ipv4' not in args and 'ipv6' not in args or not (args.ipv4 or args.ipv6): + args.ipv4 = True + args.ipv6 = False + if 'list_interfaces' in args and args.list_interfaces: - connection.sendall(pickle.dumps(Main.list_enabled_interfaces())) + connection.sendall(pickle.dumps(Main.list_enabled_interfaces(ipv4=args.ipv4, ipv6=args.ipv6))) elif 'list_neighbors' in args and args.list_neighbors: - connection.sendall(pickle.dumps(Main.list_neighbors())) + connection.sendall(pickle.dumps(Main.list_neighbors(ipv4=args.ipv4, ipv6=args.ipv6))) elif 'list_state' in args and args.list_state: - connection.sendall(pickle.dumps(Main.list_state())) + connection.sendall(pickle.dumps(Main.list_state(ipv4=args.ipv4, ipv6=args.ipv6))) elif 'add_interface' in args and args.add_interface: - Main.add_pim_interface(args.add_interface[0], False) + Main.add_pim_interface(args.add_interface[0], False, ipv4=args.ipv4, ipv6=args.ipv6) connection.shutdown(socket.SHUT_RDWR) elif 'add_interface_sr' in args and args.add_interface_sr: - Main.add_pim_interface(args.add_interface_sr[0], True) + Main.add_pim_interface(args.add_interface_sr[0], True, ipv4=args.ipv4, ipv6=args.ipv6) connection.shutdown(socket.SHUT_RDWR) elif 'add_interface_igmp' in args and args.add_interface_igmp: - Main.add_igmp_interface(args.add_interface_igmp[0]) + Main.add_membership_interface(interface_name=args.add_interface_igmp[0], ipv4=True, ipv6=False) + connection.shutdown(socket.SHUT_RDWR) + elif 'add_interface_mld' in args and args.add_interface_mld: + Main.add_membership_interface(interface_name=args.add_interface_mld[0], ipv4=False, ipv6=True) connection.shutdown(socket.SHUT_RDWR) elif 'remove_interface' in args and args.remove_interface: - Main.remove_interface(args.remove_interface[0], pim=True) + Main.remove_interface(args.remove_interface[0], pim=True, ipv4=args.ipv4, ipv6=args.ipv6) connection.shutdown(socket.SHUT_RDWR) elif 'remove_interface_igmp' in args and args.remove_interface_igmp: - Main.remove_interface(args.remove_interface_igmp[0], igmp=True) + Main.remove_interface(args.remove_interface_igmp[0], membership=True, ipv4=True, ipv6=False) + connection.shutdown(socket.SHUT_RDWR) + elif 'remove_interface_mld' in args and args.remove_interface_mld: + Main.remove_interface(args.remove_interface_mld[0], membership=True, ipv4=False, ipv6=True) connection.shutdown(socket.SHUT_RDWR) elif 'stop' in args and args.stop: Main.stop() @@ -102,18 +114,30 @@ def main(): group.add_argument("-start", "--start", action="store_true", default=False, help="Start PIM") group.add_argument("-stop", "--stop", action="store_true", default=False, help="Stop PIM") group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM") - group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces") - group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors") - group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP") - group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table") - group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface") - group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled") + group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces. " + "Use -4 or -6 to specify IPv4 or IPv6 interfaces.") + group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors. " + "Use -4 or -6 to specify IPv4 or IPv6 PIM neighbors.") + group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List IGMP/MLD and PIM-DM state machines." + " Use -4 or -6 to specify IPv4 or IPv6 state respectively.") + group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table. " + "Use -4 or -6 to specify IPv4 or IPv6 multicast routing table.") + group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface. " + "Use -4 or -6 to specify IPv4 or IPv6 interface.") + group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled. " + "Use -4 or -6 to specify IPv4 or IPv6 interface.") group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface") - group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface") + group.add_argument("-aimld", "--add_interface_mld", nargs=1, metavar='INTERFACE_NAME', help="Add MLD interface") + group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface. " + "Use -4 or -6 to specify IPv4 or IPv6 interface.") group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface") + group.add_argument("-rimld", "--remove_interface_mld", nargs=1, metavar='INTERFACE_NAME', help="Remove MLD interface") group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)") group.add_argument("-t", "--test", nargs=2, metavar=('ROUTER_NAME', 'SERVER_LOG_IP'), help="Tester... send log information to SERVER_LOG_IP. Set the router name to ROUTER_NAME") group.add_argument("--version", action='version', version='%(prog)s ' + VERSION) + group_ipversion = parser.add_mutually_exclusive_group(required=False) + group_ipversion.add_argument("-4", "--ipv4", action="store_true", default=False, help="Setting for IPv4") + group_ipversion.add_argument("-6", "--ipv6", action="store_true", default=False, help="Setting for IPv6") args = parser.parse_args() #print(parser.parse_args()) @@ -137,7 +161,10 @@ def main(): os.system("tail -f /var/log/pimdm/stdout") sys.exit(0) elif args.multicast_routes: - os.system("ip mroute show") + if args.ipv4 or not args.ipv6: + os.system("ip mroute show") + elif args.ipv6: + os.system("ip -6 mroute show") sys.exit(0) elif not daemon.is_running(): print("PIM-DM is not running") diff --git a/pimdm/UnicastRouting.py b/pimdm/UnicastRouting.py index 58af29e..4cbc0f9 100644 --- a/pimdm/UnicastRouting.py +++ b/pimdm/UnicastRouting.py @@ -1,9 +1,8 @@ import socket import ipaddress -from pyroute2 import IPDB from threading import RLock - -from pimdm.utils import if_indextoname +from socket import if_indextoname +from pyroute2 import IPDB def get_route(ip_dst: str): @@ -48,27 +47,34 @@ def check_rpf(ip_dst): @staticmethod def get_route(ip_dst: str): - ip_bytes = socket.inet_aton(ip_dst) - ip_int = int.from_bytes(ip_bytes, byteorder='big') + ip_version = ipaddress.ip_address(ip_dst).version + if ip_version == 4: + family = socket.AF_INET + full_mask = 32 + elif ip_version == 6: + family = socket.AF_INET6 + full_mask = 128 + else: + raise Exception("Unknown IP version") info = None with UnicastRouting.lock: ipdb = UnicastRouting.ipdb # type:IPDB - for mask_len in range(32, 0, -1): - ip_bytes = (ip_int & (0xFFFFFFFF << (32 - mask_len))).to_bytes(4, "big") - ip_dst = socket.inet_ntoa(ip_bytes) + "/" + str(mask_len) - print(ip_dst) - if ip_dst in ipdb.routes: + for mask_len in range(full_mask, 0, -1): + dst_network = str(ipaddress.ip_interface(ip_dst + "/" + str(mask_len)).network) + + print(dst_network) + if dst_network in ipdb.routes: print(info) - if ipdb.routes[ip_dst]['ipdb_scope'] != 'gc': - info = ipdb.routes[ip_dst] + if ipdb.routes[{'dst': dst_network, 'family': family}]['ipdb_scope'] != 'gc': + info = ipdb.routes[dst_network] break else: continue if not info: - print("0.0.0.0/0") + print("0.0.0.0/0 or ::/0") if "default" in ipdb.routes: - info = ipdb.routes["default"] + info = ipdb.routes[{'dst': 'default', 'family': family}] print(info) return info @@ -85,13 +91,16 @@ def get_unicast_info(ip_dst): oif = unicast_route.get("oif") next_hop = unicast_route["gateway"] multipaths = unicast_route["multipath"] - # prefsrc = unicast_route.get("prefsrc") + #prefsrc = unicast_route.get("prefsrc") - # rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop + #rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop rpf_node = next_hop if next_hop is not None else ip_dst - highest_ip = ipaddress.ip_address("0.0.0.0") + if ipaddress.ip_address(ip_dst).version == 4: + highest_ip = ipaddress.ip_address("0.0.0.0") + else: + highest_ip = ipaddress.ip_address("::") for m in multipaths: - if m["gateway"] is None: + if m.get("gateway", None) is None: oif = m.get('oif') rpf_node = ip_dst break @@ -107,14 +116,22 @@ def get_unicast_info(ip_dst): interface_name = None if oif is None else if_indextoname(int(oif)) from pimdm import Main - rpf_if = Main.kernel.vif_name_to_index_dic.get(interface_name) + if ipaddress.ip_address(ip_dst).version == 4: + rpf_if = Main.kernel.vif_name_to_index_dic.get(interface_name) + else: + rpf_if = Main.kernel_v6.vif_name_to_index_dic.get(interface_name) return (metric_administrative_distance, metric_cost, rpf_node, rpf_if, mask) @staticmethod def unicast_changes(ipdb, msg, action): + """ + Kernel notified about a change + Verify the type of change and recheck all trees if necessary + """ print("unicast change?") print(action) UnicastRouting.lock.acquire() + family = msg['family'] if action == "RTM_NEWROUTE" or action == "RTM_DELROUTE": print(ipdb.routes) mask_len = msg["dst_len"] @@ -126,8 +143,10 @@ def unicast_changes(ipdb, msg, action): if key == "RTA_DST": network_address = value break - if network_address is None: + if network_address is None and family == socket.AF_INET: network_address = "0.0.0.0" + elif network_address is None and family == socket.AF_INET6: + network_address = "::" print(network_address) print(mask_len) print(network_address + "/" + str(mask_len)) @@ -135,7 +154,10 @@ def unicast_changes(ipdb, msg, action): print(str(subnet)) UnicastRouting.lock.release() from pimdm import Main - Main.kernel.notify_unicast_changes(subnet) + if family == socket.AF_INET: + Main.kernel.notify_unicast_changes(subnet) + elif family == socket.AF_INET6: + Main.kernel_v6.notify_unicast_changes(subnet) ''' elif action == "RTM_NEWADDR" or action == "RTM_DELADDR": print(action) @@ -154,7 +176,7 @@ def unicast_changes(ipdb, msg, action): import traceback traceback.print_exc() pass - bnet = ipaddress.ip_network("0.0.0.0/0") + subnet = ipaddress.ip_network("0.0.0.0/0") Main.kernel.notify_unicast_changes(subnet) elif action == "RTM_NEWLINK" or action == "RTM_DELLINK": attrs = msg["attrs"] @@ -172,7 +194,7 @@ def unicast_changes(ipdb, msg, action): print(if_name + ": " + operation) UnicastRouting.lock.release() if operation == 'DOWN': - Main.kernel.remove_interface(if_name, igmp=True, pim=True) + Main.kernel.remove_interface(if_name, membership=True, pim=True) subnet = ipaddress.ip_network("0.0.0.0/0") Main.kernel.notify_unicast_changes(subnet) ''' @@ -180,6 +202,10 @@ def unicast_changes(ipdb, msg, action): UnicastRouting.lock.release() def stop(self): + """ + No longer monitor unicast changes.... + Invoked whenever the protocol is stopped + """ if self._ipdb: self._ipdb.release() if UnicastRouting.ipdb: diff --git a/pimdm/CustomTimer/RemainingTimer.py b/pimdm/custom_timer/RemainingTimer.py similarity index 100% rename from pimdm/CustomTimer/RemainingTimer.py rename to pimdm/custom_timer/RemainingTimer.py diff --git a/pimdm/CustomTimer/__init__.py b/pimdm/custom_timer/__init__.py similarity index 100% rename from pimdm/CustomTimer/__init__.py rename to pimdm/custom_timer/__init__.py diff --git a/pimdm/Daemon/Daemon.py b/pimdm/daemon/Daemon.py similarity index 100% rename from pimdm/Daemon/Daemon.py rename to pimdm/daemon/Daemon.py diff --git a/pimdm/Daemon/__init__.py b/pimdm/daemon/__init__.py similarity index 100% rename from pimdm/Daemon/__init__.py rename to pimdm/daemon/__init__.py diff --git a/pimdm/igmp/GroupState.py b/pimdm/igmp/GroupState.py index f04b067..f135e88 100644 --- a/pimdm/igmp/GroupState.py +++ b/pimdm/igmp/GroupState.py @@ -129,13 +129,13 @@ def notify_routing_add(self): with self.multicast_interface_state_lock: print("notify+", self.multicast_interface_state) for interface_state in self.multicast_interface_state: - interface_state.notify_igmp(has_members=True) + interface_state.notify_membership(has_members=True) def notify_routing_remove(self): with self.multicast_interface_state_lock: print("notify-", self.multicast_interface_state) for interface_state in self.multicast_interface_state: - interface_state.notify_igmp(has_members=False) + interface_state.notify_membership(has_members=False) def add_multicast_routing_entry(self, kernel_entry): with self.multicast_interface_state_lock: @@ -155,5 +155,5 @@ def remove(self): self.clear_timer() self.clear_v1_host_timer() for interface_state in self.multicast_interface_state: - interface_state.notify_igmp(has_members=False) + interface_state.notify_membership(has_members=False) del self.multicast_interface_state[:] diff --git a/pimdm/igmp/RouterState.py b/pimdm/igmp/RouterState.py index eb00f1c..f86345c 100644 --- a/pimdm/igmp/RouterState.py +++ b/pimdm/igmp/RouterState.py @@ -1,10 +1,10 @@ from threading import Timer import logging -from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader -from pimdm.Packet.ReceivedPacket import ReceivedPacket +from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader +from pimdm.packet.ReceivedPacket import ReceivedPacket from pimdm.utils import TYPE_CHECKING -from pimdm.RWLock.RWLock import RWLockWrite +from pimdm.rwlock.RWLock import RWLockWrite from .querier.Querier import Querier from .nonquerier.NonQuerier import NonQuerier from .GroupState import GroupState diff --git a/pimdm/igmp/nonquerier/NonQuerier.py b/pimdm/igmp/nonquerier/NonQuerier.py index 5b7d96a..90ab892 100644 --- a/pimdm/igmp/nonquerier/NonQuerier.py +++ b/pimdm/igmp/nonquerier/NonQuerier.py @@ -1,8 +1,8 @@ from ipaddress import IPv4Address from pimdm.utils import TYPE_CHECKING from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount -from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader -from pimdm.Packet.ReceivedPacket import ReceivedPacket +from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader +from pimdm.packet.ReceivedPacket import ReceivedPacket from . import NoMembersPresent, MembersPresent, CheckingMembership if TYPE_CHECKING: diff --git a/pimdm/igmp/querier/CheckingMembership.py b/pimdm/igmp/querier/CheckingMembership.py index ecf2ea4..bcb82c7 100644 --- a/pimdm/igmp/querier/CheckingMembership.py +++ b/pimdm/igmp/querier/CheckingMembership.py @@ -1,4 +1,4 @@ -from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader +from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader from pimdm.utils import TYPE_CHECKING from ..igmp_globals import Membership_Query, LastMemberQueryInterval from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent diff --git a/pimdm/igmp/querier/MembersPresent.py b/pimdm/igmp/querier/MembersPresent.py index 2638c2c..279ea89 100644 --- a/pimdm/igmp/querier/MembersPresent.py +++ b/pimdm/igmp/querier/MembersPresent.py @@ -1,4 +1,4 @@ -from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader +from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader from pimdm.utils import TYPE_CHECKING from ..igmp_globals import Membership_Query, LastMemberQueryInterval from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent diff --git a/pimdm/igmp/querier/Querier.py b/pimdm/igmp/querier/Querier.py index aad304a..1f72913 100644 --- a/pimdm/igmp/querier/Querier.py +++ b/pimdm/igmp/querier/Querier.py @@ -3,8 +3,8 @@ from pimdm.utils import TYPE_CHECKING from ..igmp_globals import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval -from pimdm.Packet.PacketIGMPHeader import PacketIGMPHeader -from pimdm.Packet.ReceivedPacket import ReceivedPacket +from pimdm.packet.PacketIGMPHeader import PacketIGMPHeader +from pimdm.packet.ReceivedPacket import ReceivedPacket from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent if TYPE_CHECKING: diff --git a/pimdm/mld/GroupState.py b/pimdm/mld/GroupState.py new file mode 100644 index 0000000..17b41cc --- /dev/null +++ b/pimdm/mld/GroupState.py @@ -0,0 +1,139 @@ +import logging +from threading import Lock +from threading import Timer + +from pimdm.utils import TYPE_CHECKING +from .wrapper import NoListenersPresent +from .mld_globals import MulticastListenerInterval, LastListenerQueryInterval + +if TYPE_CHECKING: + from .RouterState import RouterState + + +class GroupState(object): + LOGGER = logging.getLogger('pim.mld.RouterState.GroupState') + + def __init__(self, router_state: 'RouterState', group_ip: str): + #logger + extra_dict_logger = router_state.router_state_logger.extra.copy() + extra_dict_logger['tree'] = '(*,' + group_ip + ')' + self.group_state_logger = logging.LoggerAdapter(GroupState.LOGGER, extra_dict_logger) + + #timers and state + self.router_state = router_state + self.group_ip = group_ip + self.state = NoListenersPresent + self.timer = None + self.retransmit_timer = None + # lock + self.lock = Lock() + + # KernelEntry's instances to notify change of igmp state + self.multicast_interface_state = [] + self.multicast_interface_state_lock = Lock() + + def print_state(self): + return self.state.print_state() + + ########################################### + # Set state + ########################################### + def set_state(self, state): + self.state = state + self.group_state_logger.debug("change membership state to: " + state.print_state()) + + ########################################### + # Set timers + ########################################### + def set_timer(self, alternative: bool=False, max_response_time: int=None): + self.clear_timer() + if not alternative: + time = MulticastListenerInterval + else: + time = self.router_state.interface_state.get_group_membership_time(max_response_time) + + timer = Timer(time, self.group_membership_timeout) + timer.start() + self.timer = timer + + def clear_timer(self): + if self.timer is not None: + self.timer.cancel() + + def set_retransmit_timer(self): + self.clear_retransmit_timer() + retransmit_timer = Timer(LastListenerQueryInterval, self.retransmit_timeout) + retransmit_timer.start() + self.retransmit_timer = retransmit_timer + + def clear_retransmit_timer(self): + if self.retransmit_timer is not None: + self.retransmit_timer.cancel() + + + ########################################### + # Get group state from specific interface state + ########################################### + def get_interface_group_state(self): + return self.state.get_state(self.router_state) + + ########################################### + # Timer timeout + ########################################### + def group_membership_timeout(self): + with self.lock: + self.get_interface_group_state().group_membership_timeout(self) + + def retransmit_timeout(self): + with self.lock: + self.get_interface_group_state().retransmit_timeout(self) + + ########################################### + # Receive Packets + ########################################### + def receive_report(self): + with self.lock: + self.get_interface_group_state().receive_report(self) + + def receive_done(self): + with self.lock: + self.get_interface_group_state().receive_done(self) + + def receive_group_specific_query(self, max_response_time: int): + with self.lock: + self.get_interface_group_state().receive_group_specific_query(self, max_response_time) + + ########################################### + # Notify Routing + ########################################### + def notify_routing_add(self): + with self.multicast_interface_state_lock: + print("notify+", self.multicast_interface_state) + for interface_state in self.multicast_interface_state: + interface_state.notify_membership(has_members=True) + + def notify_routing_remove(self): + with self.multicast_interface_state_lock: + print("notify-", self.multicast_interface_state) + for interface_state in self.multicast_interface_state: + interface_state.notify_membership(has_members=False) + + def add_multicast_routing_entry(self, kernel_entry): + with self.multicast_interface_state_lock: + self.multicast_interface_state.append(kernel_entry) + return self.has_members() + + def remove_multicast_routing_entry(self, kernel_entry): + with self.multicast_interface_state_lock: + self.multicast_interface_state.remove(kernel_entry) + + def has_members(self): + return self.state is not NoListenersPresent + + def remove(self): + with self.multicast_interface_state_lock: + self.clear_retransmit_timer() + self.clear_timer() + for interface_state in self.multicast_interface_state: + interface_state.notify_membership(has_members=False) + del self.multicast_interface_state[:] diff --git a/pimdm/mld/RouterState.py b/pimdm/mld/RouterState.py new file mode 100644 index 0000000..0554e04 --- /dev/null +++ b/pimdm/mld/RouterState.py @@ -0,0 +1,137 @@ +import logging +from threading import Timer + +from pimdm.packet.PacketMLDHeader import PacketMLDHeader +from pimdm.packet.ReceivedPacket import ReceivedPacket +from pimdm.utils import TYPE_CHECKING +from pimdm.rwlock.RWLock import RWLockWrite +from .querier.Querier import Querier +from .nonquerier.NonQuerier import NonQuerier +from .GroupState import GroupState +from .mld_globals import QueryResponseInterval, QueryInterval, OtherQuerierPresentInterval, MULTICAST_LISTENER_QUERY_TYPE + +if TYPE_CHECKING: + from pimdm.InterfaceMLD import InterfaceMLD + + +class RouterState(object): + ROUTER_STATE_LOGGER = logging.getLogger('pim.mld.RouterState') + + def __init__(self, interface: 'InterfaceMLD'): + #logger + logger_extra = dict() + logger_extra['vif'] = interface.vif_index + logger_extra['interfacename'] = interface.interface_name + self.router_state_logger = logging.LoggerAdapter(RouterState.ROUTER_STATE_LOGGER, logger_extra) + + # interface of the router connected to the network + self.interface = interface + + # state of the router (Querier/NonQuerier) + self.interface_state = Querier + + # state of each group + # Key: GroupIPAddress, Value: GroupState object + self.group_state = {} + self.group_state_lock = RWLockWrite() + + # send general query + packet = PacketMLDHeader(type=MULTICAST_LISTENER_QUERY_TYPE, max_resp_delay=QueryResponseInterval*1000) + self.interface.send(packet.bytes()) + + # set initial general query timer + timer = Timer(QueryInterval, self.general_query_timeout) + timer.start() + self.general_query_timer = timer + + # present timer + self.other_querier_present_timer = None + + # Send packet via interface + def send(self, data: bytes, address: str): + self.interface.send(data, address) + + ############################################ + # interface_state methods + ############################################ + def print_state(self): + return self.interface_state.state_name() + + def set_general_query_timer(self): + self.clear_general_query_timer() + general_query_timer = Timer(QueryInterval, self.general_query_timeout) + general_query_timer.start() + self.general_query_timer = general_query_timer + + def clear_general_query_timer(self): + if self.general_query_timer is not None: + self.general_query_timer.cancel() + + def set_other_querier_present_timer(self): + self.clear_other_querier_present_timer() + other_querier_present_timer = Timer(OtherQuerierPresentInterval, self.other_querier_present_timeout) + other_querier_present_timer.start() + self.other_querier_present_timer = other_querier_present_timer + + def clear_other_querier_present_timer(self): + if self.other_querier_present_timer is not None: + self.other_querier_present_timer.cancel() + + def general_query_timeout(self): + self.interface_state.general_query_timeout(self) + + def other_querier_present_timeout(self): + self.interface_state.other_querier_present_timeout(self) + + def change_interface_state(self, querier: bool): + if querier: + self.interface_state = Querier + self.router_state_logger.debug('change querier state to -> Querier') + else: + self.interface_state = NonQuerier + self.router_state_logger.debug('change querier state to -> NonQuerier') + + ############################################ + # group state methods + ############################################ + def get_group_state(self, group_ip): + with self.group_state_lock.genRlock(): + if group_ip in self.group_state: + return self.group_state[group_ip] + + with self.group_state_lock.genWlock(): + if group_ip in self.group_state: + group_state = self.group_state[group_ip] + else: + group_state = GroupState(self, group_ip) + self.group_state[group_ip] = group_state + return group_state + + def receive_report(self, packet: ReceivedPacket): + mld_group = packet.payload.group_address + #if igmp_group not in self.group_state: + # self.group_state[igmp_group] = GroupState(self, igmp_group) + + #self.group_state[igmp_group].receive_v2_membership_report() + self.get_group_state(mld_group).receive_report() + + def receive_done(self, packet: ReceivedPacket): + mld_group = packet.payload.group_address + #if igmp_group in self.group_state: + # self.group_state[igmp_group].receive_leave_group() + self.get_group_state(mld_group).receive_done() + + def receive_query(self, packet: ReceivedPacket): + self.interface_state.receive_query(self, packet) + mld_group = packet.payload.group_address + + # process group specific query + if mld_group != "::" and mld_group in self.group_state: + #if igmp_group != "0.0.0.0": + max_response_time = packet.payload.max_resp_delay + #self.group_state[igmp_group].receive_group_specific_query(max_response_time) + self.get_group_state(mld_group).receive_group_specific_query(max_response_time) + + def remove(self): + for group in self.group_state.values(): + group.remove() diff --git a/pimdm/Packet/__init__.py b/pimdm/mld/__init__.py similarity index 100% rename from pimdm/Packet/__init__.py rename to pimdm/mld/__init__.py diff --git a/pimdm/mld/mld_globals.py b/pimdm/mld/mld_globals.py new file mode 100644 index 0000000..882063d --- /dev/null +++ b/pimdm/mld/mld_globals.py @@ -0,0 +1,17 @@ +#MLD timers (in seconds) +RobustnessVariable = 2 +QueryInterval = 125 +QueryResponseInterval = 10 +MulticastListenerInterval = (RobustnessVariable * QueryInterval) + (QueryResponseInterval) +OtherQuerierPresentInterval = (RobustnessVariable * QueryInterval) + 0.5 * QueryResponseInterval +StartupQueryInterval = (1/4) * QueryInterval +StartupQueryCount = RobustnessVariable +LastListenerQueryInterval = 1 +LastListenerQueryCount = RobustnessVariable +UnsolicitedReportInterval = 10 + + +# MLD msg type +MULTICAST_LISTENER_QUERY_TYPE = 130 +MULTICAST_LISTENER_REPORT_TYPE = 131 +MULTICAST_LISTENER_DONE_TYPE = 132 \ No newline at end of file diff --git a/pimdm/mld/nonquerier/CheckingListeners.py b/pimdm/mld/nonquerier/CheckingListeners.py new file mode 100644 index 0000000..8273439 --- /dev/null +++ b/pimdm/mld/nonquerier/CheckingListeners.py @@ -0,0 +1,38 @@ +from pimdm.utils import TYPE_CHECKING +from ..wrapper import NoListenersPresent +from ..wrapper import ListenersPresent + +if TYPE_CHECKING: + from ..GroupState import GroupState + + +def receive_report(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_report') + group_state.set_timer() + group_state.set_state(ListenersPresent) + + +def receive_done(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_done') + # do nothing + return + + +def receive_group_specific_query(group_state: 'GroupState', max_response_time: int): + group_state.group_state_logger.debug('NonQuerier CheckingListeners: receive_group_specific_query') + # do nothing + return + + +def group_membership_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier CheckingListeners: group_membership_timeout') + group_state.set_state(NoListenersPresent) + + # NOTIFY ROUTING - !!!! + group_state.notify_routing_remove() + + +def retransmit_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier CheckingListeners: retransmit_timeout') + # do nothing + return diff --git a/pimdm/mld/nonquerier/ListenersPresent.py b/pimdm/mld/nonquerier/ListenersPresent.py new file mode 100644 index 0000000..9145422 --- /dev/null +++ b/pimdm/mld/nonquerier/ListenersPresent.py @@ -0,0 +1,39 @@ +from pimdm.utils import TYPE_CHECKING +from ..wrapper import NoListenersPresent +from ..wrapper import CheckingListeners + +if TYPE_CHECKING: + from ..GroupState import GroupState + + +def receive_report(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_report') + group_state.set_timer() + + +def receive_done(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_done') + # do nothing + return + + +def receive_group_specific_query(group_state: 'GroupState', max_response_time: int): + group_state.group_state_logger.debug('NonQuerier ListenersPresent: receive_group_specific_query') + group_state.set_timer(alternative=True, max_response_time=max_response_time) + group_state.set_state(CheckingListeners) + + +def group_membership_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier ListenersPresent: group_membership_timeout') + group_state.set_state(NoListenersPresent) + + # NOTIFY ROUTING - !!!! + group_state.notify_routing_remove() + + +def retransmit_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier ListenersPresent: retransmit_timeout') + # do nothing + return + + diff --git a/pimdm/mld/nonquerier/NoListenersPresent.py b/pimdm/mld/nonquerier/NoListenersPresent.py new file mode 100644 index 0000000..c108cc7 --- /dev/null +++ b/pimdm/mld/nonquerier/NoListenersPresent.py @@ -0,0 +1,38 @@ +from pimdm.utils import TYPE_CHECKING +from ..wrapper import ListenersPresent + +if TYPE_CHECKING: + from ..GroupState import GroupState + + +def receive_report(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_report') + group_state.set_timer() + group_state.set_state(ListenersPresent) + + # NOTIFY ROUTING + !!!! + group_state.notify_routing_add() + + +def receive_done(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_done') + # do nothing + return + + +def receive_group_specific_query(group_state: 'GroupState', max_response_time: int): + group_state.group_state_logger.debug('NonQuerier NoListenersPresent: receive_group_specific_query') + # do nothing + return + + +def group_membership_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier NoListenersPresent: group_membership_timeout') + # do nothing + return + + +def retransmit_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('NonQuerier NoListenersPresent: retransmit_timeout') + # do nothing + return diff --git a/pimdm/mld/nonquerier/NonQuerier.py b/pimdm/mld/nonquerier/NonQuerier.py new file mode 100644 index 0000000..222af7b --- /dev/null +++ b/pimdm/mld/nonquerier/NonQuerier.py @@ -0,0 +1,65 @@ +from ipaddress import IPv6Address +from pimdm.utils import TYPE_CHECKING +from ..mld_globals import QueryResponseInterval, LastListenerQueryCount +from pimdm.packet.PacketMLDHeader import PacketMLDHeader +from pimdm.packet.ReceivedPacket import ReceivedPacket +from . import NoListenersPresent, ListenersPresent, CheckingListeners + +if TYPE_CHECKING: + from ..RouterState import RouterState + + +class NonQuerier: + @staticmethod + def general_query_timeout(router_state: 'RouterState'): + router_state.router_state_logger.debug('NonQuerier state: general_query_timeout') + # do nothing + return + + @staticmethod + def other_querier_present_timeout(router_state: 'RouterState'): + router_state.router_state_logger.debug('NonQuerier state: other_querier_present_timeout') + #change state to Querier + router_state.change_interface_state(querier=True) + + # send general query + packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE, + max_resp_delay=QueryResponseInterval*1000) + router_state.interface.send(packet.bytes()) + + # set general query timer + router_state.set_general_query_timer() + + @staticmethod + def receive_query(router_state: 'RouterState', packet: ReceivedPacket): + router_state.router_state_logger.debug('NonQuerier state: receive_query') + source_ip = packet.ip_header.ip_src + + # if source ip of membership query not lower than the ip of the received interface => ignore + if IPv6Address(source_ip) >= IPv6Address(router_state.interface.get_ip()): + return + + # reset other present querier timer + router_state.set_other_querier_present_timer() + + # TODO ver se existe uma melhor maneira de fazer isto + @staticmethod + def state_name(): + return "Non Querier" + + @staticmethod + def get_group_membership_time(max_response_time: int): + return (max_response_time/1000.0) * LastListenerQueryCount + + # State + @staticmethod + def get_checking_listeners_state(): + return CheckingListeners + + @staticmethod + def get_listeners_present_state(): + return ListenersPresent + + @staticmethod + def get_no_listeners_present_state(): + return NoListenersPresent diff --git a/pimdm/RWLock/__init__.py b/pimdm/mld/nonquerier/__init__.py similarity index 100% rename from pimdm/RWLock/__init__.py rename to pimdm/mld/nonquerier/__init__.py diff --git a/pimdm/mld/querier/CheckingListeners.py b/pimdm/mld/querier/CheckingListeners.py new file mode 100644 index 0000000..2fb8431 --- /dev/null +++ b/pimdm/mld/querier/CheckingListeners.py @@ -0,0 +1,43 @@ +from pimdm.packet.PacketMLDHeader import PacketMLDHeader +from pimdm.utils import TYPE_CHECKING +from ..mld_globals import LastListenerQueryInterval +from ..wrapper import ListenersPresent, NoListenersPresent +if TYPE_CHECKING: + from ..GroupState import GroupState + + +def receive_report(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier CheckingListeners: receive_report') + group_state.set_timer() + group_state.clear_retransmit_timer() + group_state.set_state(ListenersPresent) + +def receive_done(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier CheckingListeners: receive_done') + # do nothing + return + + +def receive_group_specific_query(group_state: 'GroupState', max_response_time: int): + group_state.group_state_logger.debug('Querier CheckingListeners: receive_group_specific_query') + # do nothing + return + + +def group_membership_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier CheckingListeners: group_membership_timeout') + group_state.clear_retransmit_timer() + group_state.set_state(NoListenersPresent) + + # NOTIFY ROUTING - !!!! + group_state.notify_routing_remove() + + +def retransmit_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier CheckingListeners: retransmit_timeout') + group_addr = group_state.group_ip + packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE, + max_resp_delay=LastListenerQueryInterval*1000, group_address=group_addr) + group_state.router_state.send(data=packet.bytes(), address=group_addr) + + group_state.set_retransmit_timer() diff --git a/pimdm/mld/querier/ListenersPresent.py b/pimdm/mld/querier/ListenersPresent.py new file mode 100644 index 0000000..d939fdf --- /dev/null +++ b/pimdm/mld/querier/ListenersPresent.py @@ -0,0 +1,47 @@ +from pimdm.packet.PacketMLDHeader import PacketMLDHeader +from pimdm.utils import TYPE_CHECKING +from ..mld_globals import LastListenerQueryInterval +from ..wrapper import CheckingListeners, NoListenersPresent +if TYPE_CHECKING: + from ..GroupState import GroupState + + +def receive_report(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier ListenersPresent: receive_report') + group_state.set_timer() + + +def receive_done(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier ListenersPresent: receive_done') + group_ip = group_state.group_ip + + group_state.set_timer(alternative=True) + group_state.set_retransmit_timer() + + packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE, + max_resp_delay=LastListenerQueryInterval*1000, group_address=group_ip) + group_state.router_state.send(data=packet.bytes(), address=group_ip) + + group_state.set_state(CheckingListeners) + + +def receive_group_specific_query(group_state: 'GroupState', max_response_time): + group_state.group_state_logger.debug('Querier ListenersPresent: receive_group_specific_query') + # do nothing + return + + +def group_membership_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier ListenersPresent: group_membership_timeout') + group_state.set_state(NoListenersPresent) + + # NOTIFY ROUTING - !!!! + group_state.notify_routing_remove() + + +def retransmit_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier ListenersPresent: retransmit_timeout') + # do nothing + return + + diff --git a/pimdm/mld/querier/NoListenersPresent.py b/pimdm/mld/querier/NoListenersPresent.py new file mode 100644 index 0000000..b40b0e7 --- /dev/null +++ b/pimdm/mld/querier/NoListenersPresent.py @@ -0,0 +1,38 @@ +from pimdm.utils import TYPE_CHECKING +from ..wrapper import ListenersPresent +if TYPE_CHECKING: + from ..GroupState import GroupState + + +def receive_report(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier NoListenersPresent: receive_report') + group_state.set_timer() + group_state.set_state(ListenersPresent) + + # NOTIFY ROUTING + !!!! + group_state.notify_routing_add() + + +def receive_done(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier NoListenersPresent: receive_done') + # do nothing + return + + +def receive_group_specific_query(group_state: 'GroupState', max_response_time: int): + group_state.group_state_logger.debug('Querier NoListenersPresent: receive_group_specific_query') + # do nothing + return + + +def group_membership_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier NoListenersPresent: group_membership_timeout') + # do nothing + return + + +def retransmit_timeout(group_state: 'GroupState'): + group_state.group_state_logger.debug('Querier NoListenersPresent: retransmit_timeout') + # do nothing + return + diff --git a/pimdm/mld/querier/Querier.py b/pimdm/mld/querier/Querier.py new file mode 100644 index 0000000..f4c3fef --- /dev/null +++ b/pimdm/mld/querier/Querier.py @@ -0,0 +1,70 @@ +from ipaddress import IPv6Address + +from pimdm.utils import TYPE_CHECKING +from ..mld_globals import LastListenerQueryInterval, LastListenerQueryCount, QueryResponseInterval + +from pimdm.packet.PacketMLDHeader import PacketMLDHeader +from pimdm.packet.ReceivedPacket import ReceivedPacket +from . import CheckingListeners, ListenersPresent, NoListenersPresent + +if TYPE_CHECKING: + from ..RouterState import RouterState + + +class Querier: + @staticmethod + def general_query_timeout(router_state: 'RouterState'): + router_state.router_state_logger.debug('Querier state: general_query_timeout') + # send general query + packet = PacketMLDHeader(type=PacketMLDHeader.MULTICAST_LISTENER_QUERY_TYPE, + max_resp_delay=QueryResponseInterval*1000) + router_state.interface.send(packet.bytes()) + + # set general query timer + router_state.set_general_query_timer() + + @staticmethod + def receive_query(router_state: 'RouterState', packet: ReceivedPacket): + router_state.router_state_logger.debug('Querier state: receive_query') + source_ip = packet.ip_header.ip_src + + # if source ip of membership query not lower than the ip of the received interface => ignore + if IPv6Address(source_ip) >= IPv6Address(router_state.interface.get_ip()): + return + + # if source ip of membership query lower than the ip of the received interface => change state + # change state of interface + # Querier -> Non Querier + router_state.change_interface_state(querier=False) + + # set other present querier timer + router_state.clear_general_query_timer() + router_state.set_other_querier_present_timer() + + @staticmethod + def other_querier_present_timeout(router_state: 'RouterState'): + router_state.router_state_logger.debug('Querier state: other_querier_present_timeout') + # do nothing + return + + # TODO ver se existe uma melhor maneira de fazer isto + @staticmethod + def state_name(): + return "Querier" + + @staticmethod + def get_group_membership_time(max_response_time: int): + return LastListenerQueryInterval * LastListenerQueryCount + + # State + @staticmethod + def get_checking_listeners_state(): + return CheckingListeners + + @staticmethod + def get_listeners_present_state(): + return ListenersPresent + + @staticmethod + def get_no_listeners_present_state(): + return NoListenersPresent diff --git a/pimdm/mld/querier/__init__.py b/pimdm/mld/querier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pimdm/mld/wrapper/CheckingListeners.py b/pimdm/mld/wrapper/CheckingListeners.py new file mode 100644 index 0000000..eca9ea4 --- /dev/null +++ b/pimdm/mld/wrapper/CheckingListeners.py @@ -0,0 +1,40 @@ +from pimdm.utils import TYPE_CHECKING +if TYPE_CHECKING: + from ..RouterState import RouterState + + +def get_state(router_state: 'RouterState'): + return router_state.interface_state.get_checking_listeners_state() + + +def print_state(): + return "CheckingListeners" +''' +def group_membership_timeout(group_state): + get_state(group_state).group_membership_timeout(group_state) + + +def group_membership_v1_timeout(group_state): + get_state(group_state).group_membership_v1_timeout(group_state) + + +def retransmit_timeout(group_state): + get_state(group_state).retransmit_timeout(group_state) + + +def receive_v1_membership_report(group_state, packet: ReceivedPacket): + get_state(group_state).receive_v1_membership_report(group_state, packet) + + +def receive_v2_membership_report(group_state, packet: ReceivedPacket): + get_state(group_state).receive_v2_membership_report(group_state, packet) + + +def receive_leave_group(group_state, packet: ReceivedPacket): + get_state(group_state).receive_leave_group(group_state, packet) + + +def receive_group_specific_query(group_state, packet: ReceivedPacket): + get_state(group_state).receive_group_specific_query(group_state, packet) + +''' \ No newline at end of file diff --git a/pimdm/mld/wrapper/ListenersPresent.py b/pimdm/mld/wrapper/ListenersPresent.py new file mode 100644 index 0000000..81e4605 --- /dev/null +++ b/pimdm/mld/wrapper/ListenersPresent.py @@ -0,0 +1,41 @@ +from pimdm.utils import TYPE_CHECKING +if TYPE_CHECKING: + from ..RouterState import RouterState + + +def get_state(router_state: 'RouterState'): + return router_state.interface_state.get_listeners_present_state() + + +def print_state(): + return "ListenersPresent" + +''' +def group_membership_timeout(group_state): + get_state(group_state).group_membership_timeout(group_state) + + +def group_membership_v1_timeout(group_state): + get_state(group_state).group_membership_v1_timeout(group_state) + + +def retransmit_timeout(group_state): + get_state(group_state).retransmit_timeout(group_state) + + +def receive_v1_membership_report(group_state, packet: ReceivedPacket): + get_state(group_state).receive_v1_membership_report(group_state, packet) + + +def receive_v2_membership_report(group_state, packet: ReceivedPacket): + get_state(group_state).receive_v2_membership_report(group_state, packet) + + +def receive_leave_group(group_state, packet: ReceivedPacket): + get_state(group_state).receive_leave_group(group_state, packet) + + +def receive_group_specific_query(group_state, packet: ReceivedPacket): + get_state(group_state).receive_group_specific_query(group_state, packet) + +''' \ No newline at end of file diff --git a/pimdm/mld/wrapper/NoListenersPresent.py b/pimdm/mld/wrapper/NoListenersPresent.py new file mode 100644 index 0000000..063c81c --- /dev/null +++ b/pimdm/mld/wrapper/NoListenersPresent.py @@ -0,0 +1,39 @@ +from pimdm.utils import TYPE_CHECKING +if TYPE_CHECKING: + from ..RouterState import RouterState + + +def get_state(router_state: 'RouterState'): + return router_state.interface_state.get_no_listeners_present_state() + + +def print_state(): + return "NoListenersPresent" +''' +def group_membership_timeout(group_state): + get_state(group_state).group_membership_timeout(group_state) + + +def group_membership_v1_timeout(group_state): + get_state(group_state).group_membership_v1_timeout(group_state) + + +def retransmit_timeout(group_state): + get_state(group_state).retransmit_timeout(group_state) + + +def receive_v1_membership_report(group_state, packet: ReceivedPacket): + get_state(group_state).receive_v1_membership_report(group_state, packet) + + +def receive_v2_membership_report(group_state, packet: ReceivedPacket): + get_state(group_state).receive_v2_membership_report(group_state, packet) + + +def receive_leave_group(group_state, packet: ReceivedPacket): + get_state(group_state).receive_leave_group(group_state, packet) + + +def receive_group_specific_query(group_state, packet: ReceivedPacket): + get_state(group_state).receive_group_specific_query(group_state, packet) +''' diff --git a/pimdm/mld/wrapper/__init__.py b/pimdm/mld/wrapper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pimdm/Packet/Packet.py b/pimdm/packet/Packet.py similarity index 100% rename from pimdm/Packet/Packet.py rename to pimdm/packet/Packet.py diff --git a/pimdm/Packet/PacketIGMPHeader.py b/pimdm/packet/PacketIGMPHeader.py similarity index 100% rename from pimdm/Packet/PacketIGMPHeader.py rename to pimdm/packet/PacketIGMPHeader.py diff --git a/pimdm/packet/PacketIpHeader.py b/pimdm/packet/PacketIpHeader.py new file mode 100644 index 0000000..75fb56a --- /dev/null +++ b/pimdm/packet/PacketIpHeader.py @@ -0,0 +1,155 @@ +import struct +import socket + + +class PacketIpHeader: + """ + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |Version| + +-+-+-+-+ + """ + IP_HDR = "! B" + IP_HDR_LEN = struct.calcsize(IP_HDR) + + def __init__(self, ver, hdr_len): + self.version = ver + self.hdr_length = hdr_len + + def __len__(self): + return self.hdr_length + + @staticmethod + def parse_bytes(data: bytes): + (verhlen, ) = struct.unpack(PacketIpHeader.IP_HDR, data[:PacketIpHeader.IP_HDR_LEN]) + ver = (verhlen & 0xF0) >> 4 + print("ver:", ver) + return PACKET_HEADER.get(ver).parse_bytes(data) + + +class PacketIpv4Header(PacketIpHeader): + """ + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |Version| IHL |Type of Service| Total Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Identification |Flags| Fragment Offset | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Time to Live | Protocol | Header Checksum | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Source Address | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Destination Address | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Options | Padding | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + """ + IP_HDR = "! BBH HH BBH 4s 4s" + IP_HDR_LEN = struct.calcsize(IP_HDR) + + def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst): + super().__init__(ver, hdr_len) + self.ttl = ttl + self.proto = proto + self.ip_src = ip_src + self.ip_dst = ip_dst + + def __len__(self): + return self.hdr_length + + @staticmethod + def parse_bytes(data: bytes): + (verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \ + struct.unpack(PacketIpv4Header.IP_HDR, data[:PacketIpv4Header.IP_HDR_LEN]) + + ver = (verhlen & 0xf0) >> 4 + hlen = (verhlen & 0x0f) * 4 + + ''' + "VER": ver, + "HLEN": hlen, + "TOS": tos, + "IPLEN": iplen, + "IPID": ipid, + "FRAG": frag, + "TTL": ttl, + "PROTO": proto, + "CKSUM": cksum, + "SRC": socket.inet_ntoa(src), + "DST": socket.inet_ntoa(dst) + ''' + src_ip = socket.inet_ntoa(src) + dst_ip = socket.inet_ntoa(dst) + return PacketIpv4Header(ver, hlen, ttl, proto, src_ip, dst_ip) + + +class PacketIpv6Header(PacketIpHeader): + """ + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |Version| Traffic Class | Flow Label | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Payload Length | Next Header | Hop Limit | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + + + + | | + + Source Address + + | | + + + + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + + + + | | + + Destination Address + + | | + + + + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + """ + IP6_HDR = "! I HBB 16s 16s" + IP6_HDR_LEN = struct.calcsize(IP6_HDR) + + def __init__(self, ver, next_header, hop_limit, ip_src, ip_dst): + # TODO: confirm hdr_length in case of multiple options/headers + super().__init__(ver, PacketIpv6Header.IP6_HDR_LEN) + self.next_header = next_header + self.hop_limit = hop_limit + self.ip_src = ip_src + self.ip_dst = ip_dst + + def __len__(self): + return PacketIpv6Header.IP6_HDR_LEN + + @staticmethod + def parse_bytes(data: bytes): + (ver_tc_fl, _, next_header, hop_limit, src, dst) = \ + struct.unpack(PacketIpv6Header.IP6_HDR, data[:PacketIpv6Header.IP6_HDR_LEN]) + + ver = (ver_tc_fl & 0xf0000000) >> 28 + #tc = (ver_tc_fl & 0x0ff00000) >> 20 + #fl = (ver_tc_fl & 0x000fffff) + ''' + "VER": ver, + "TRAFFIC CLASS": tc, + "FLOW LABEL": fl, + "PAYLOAD LEN": payload_length, + "NEXT HEADER": next_header, + "HOP LIMIT": hop_limit, + "SRC": socket.inet_atop(socket.AF_INET6, src), + "DST": socket.inet_atop(socket.AF_INET6, dst) + ''' + + src_ip = socket.inet_ntop(socket.AF_INET6, src) + dst_ip = socket.inet_ntop(socket.AF_INET6, dst) + return PacketIpv6Header(ver, next_header, hop_limit, src_ip, dst_ip) + + +PACKET_HEADER = { + 4: PacketIpv4Header, + 6: PacketIpv6Header, +} diff --git a/pimdm/packet/PacketMLDHeader.py b/pimdm/packet/PacketMLDHeader.py new file mode 100644 index 0000000..1f79190 --- /dev/null +++ b/pimdm/packet/PacketMLDHeader.py @@ -0,0 +1,64 @@ +import struct +import socket +from .PacketPayload import PacketPayload + +""" + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type | Code | Checksum | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Maximum Response Delay | Reserved | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | ++ + +| | ++ Multicast Address + +| | ++ + +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +""" +class PacketMLDHeader(PacketPayload): + MLD_TYPE = 58 + + MLD_HDR = "! BB H H H 16s" + MLD_HDR_LEN = struct.calcsize(MLD_HDR) + + MULTICAST_LISTENER_QUERY_TYPE = 130 + MULTICAST_LISTENER_REPORT_TYPE = 131 + MULTICAST_LISTENER_DONE_TYPE = 132 + + def __init__(self, type: int, max_resp_delay: int, group_address: str = "::"): + # todo check type + self.type = type + self.max_resp_delay = max_resp_delay + self.group_address = group_address + + def get_mld_type(self): + return self.type + + def bytes(self) -> bytes: + # obter mensagem e criar checksum + msg_without_chcksum = struct.pack(PacketMLDHeader.MLD_HDR, self.type, 0, 0, self.max_resp_delay, 0, + socket.inet_pton(socket.AF_INET6, self.group_address)) + #mld_checksum = checksum(msg_without_chcksum) + #msg = msg_without_chcksum[0:2] + struct.pack("! H", mld_checksum) + msg_without_chcksum[4:] + # checksum handled by linux kernel + return msg_without_chcksum + + def __len__(self): + return len(self.bytes()) + + + @staticmethod + def parse_bytes(data: bytes): + mld_hdr = data[0:PacketMLDHeader.MLD_HDR_LEN] + if len(mld_hdr) < PacketMLDHeader.MLD_HDR_LEN: + raise Exception("MLD packet length is lower than expected") + (mld_type, _, _, max_resp_delay, _, group_address) = struct.unpack(PacketMLDHeader.MLD_HDR, mld_hdr) + # checksum is handled by linux kernel + mld_hdr = mld_hdr[PacketMLDHeader.MLD_HDR_LEN:] + group_address = socket.inet_ntop(socket.AF_INET6, group_address) + pkt = PacketMLDHeader(mld_type, max_resp_delay, group_address) + return pkt diff --git a/pimdm/Packet/PacketPayload.py b/pimdm/packet/PacketPayload.py similarity index 100% rename from pimdm/Packet/PacketPayload.py rename to pimdm/packet/PacketPayload.py diff --git a/pimdm/Packet/PacketPimAssert.py b/pimdm/packet/PacketPimAssert.py similarity index 100% rename from pimdm/Packet/PacketPimAssert.py rename to pimdm/packet/PacketPimAssert.py diff --git a/pimdm/Packet/PacketPimEncodedGroupAddress.py b/pimdm/packet/PacketPimEncodedGroupAddress.py similarity index 95% rename from pimdm/Packet/PacketPimEncodedGroupAddress.py rename to pimdm/packet/PacketPimEncodedGroupAddress.py index 0849d4e..a49c47d 100644 --- a/pimdm/Packet/PacketPimEncodedGroupAddress.py +++ b/pimdm/packet/PacketPimEncodedGroupAddress.py @@ -55,7 +55,7 @@ def get_ip_info(ip): elif version == 6: return (PacketPimEncodedGroupAddress.IPV6_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV6, socket.AF_INET6) else: - raise Exception + raise Exception("Unknown address family") def __len__(self): version = ipaddress.ip_address(self.group_address).version @@ -64,7 +64,7 @@ def __len__(self): elif version == 6: return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6 else: - raise Exception + raise Exception("Unknown address family") @staticmethod def parse_bytes(data: bytes): @@ -72,13 +72,14 @@ def parse_bytes(data: bytes): (addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS, data_without_group_addr) data_group_addr = data[PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN:] - ip = None if addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV4: (ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV4_HDR, data_group_addr[:4]) ip = socket.inet_ntop(socket.AF_INET, ip) elif addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV6: (ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV6_HDR, data_group_addr[:16]) ip = socket.inet_ntop(socket.AF_INET6, ip) + else: + raise Exception("Unknown address family") if encoding != 0: print("unknown encoding") diff --git a/pimdm/Packet/PacketPimEncodedSourceAddress.py b/pimdm/packet/PacketPimEncodedSourceAddress.py similarity index 95% rename from pimdm/Packet/PacketPimEncodedSourceAddress.py rename to pimdm/packet/PacketPimEncodedSourceAddress.py index 2227407..b6bba07 100644 --- a/pimdm/Packet/PacketPimEncodedSourceAddress.py +++ b/pimdm/packet/PacketPimEncodedSourceAddress.py @@ -57,7 +57,7 @@ def get_ip_info(ip): elif version == 6: return (PacketPimEncodedSourceAddress.IPV6_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV6, socket.AF_INET6) else: - raise Exception + raise Exception("Unknown address family") def __len__(self): version = ipaddress.ip_address(self.source_address).version @@ -66,7 +66,7 @@ def __len__(self): elif version == 6: return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6 else: - raise Exception + raise Exception("Unknown address family") @staticmethod def parse_bytes(data: bytes): @@ -74,13 +74,14 @@ def parse_bytes(data: bytes): (addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS, data_without_source_addr) data_source_addr = data[PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN:] - ip = None if addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV4: (ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV4_HDR, data_source_addr[:4]) ip = socket.inet_ntop(socket.AF_INET, ip) elif addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV6: (ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV6_HDR, data_source_addr[:16]) ip = socket.inet_ntop(socket.AF_INET6, ip) + else: + raise Exception("Unknown address family") if encoding != 0: print("unknown encoding") diff --git a/pimdm/Packet/PacketPimEncodedUnicastAddress.py b/pimdm/packet/PacketPimEncodedUnicastAddress.py similarity index 95% rename from pimdm/Packet/PacketPimEncodedUnicastAddress.py rename to pimdm/packet/PacketPimEncodedUnicastAddress.py index 6f23126..794e46e 100644 --- a/pimdm/Packet/PacketPimEncodedUnicastAddress.py +++ b/pimdm/packet/PacketPimEncodedUnicastAddress.py @@ -46,7 +46,7 @@ def get_ip_info(ip): elif version == 6: return (PacketPimEncodedUnicastAddress.IPV6_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV6, socket.AF_INET6) else: - raise Exception + raise Exception("Unknown address family") def __len__(self): version = ipaddress.ip_address(self.unicast_address).version @@ -55,7 +55,7 @@ def __len__(self): elif version == 6: return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6 else: - raise Exception + raise Exception("Unknown address family") @staticmethod def parse_bytes(data: bytes): @@ -69,6 +69,8 @@ def parse_bytes(data: bytes): elif addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV6: (ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV6_HDR, data_unicast_addr[:16]) ip = socket.inet_ntop(socket.AF_INET6, ip) + else: + raise Exception("Unknown address family") if encoding != 0: print("unknown encoding") diff --git a/pimdm/Packet/PacketPimGraft.py b/pimdm/packet/PacketPimGraft.py similarity index 100% rename from pimdm/Packet/PacketPimGraft.py rename to pimdm/packet/PacketPimGraft.py diff --git a/pimdm/Packet/PacketPimGraftAck.py b/pimdm/packet/PacketPimGraftAck.py similarity index 100% rename from pimdm/Packet/PacketPimGraftAck.py rename to pimdm/packet/PacketPimGraftAck.py diff --git a/pimdm/Packet/PacketPimHeader.py b/pimdm/packet/PacketPimHeader.py similarity index 100% rename from pimdm/Packet/PacketPimHeader.py rename to pimdm/packet/PacketPimHeader.py diff --git a/pimdm/Packet/PacketPimHello.py b/pimdm/packet/PacketPimHello.py similarity index 100% rename from pimdm/Packet/PacketPimHello.py rename to pimdm/packet/PacketPimHello.py diff --git a/pimdm/Packet/PacketPimHelloOptions.py b/pimdm/packet/PacketPimHelloOptions.py similarity index 100% rename from pimdm/Packet/PacketPimHelloOptions.py rename to pimdm/packet/PacketPimHelloOptions.py diff --git a/pimdm/Packet/PacketPimJoinPrune.py b/pimdm/packet/PacketPimJoinPrune.py similarity index 100% rename from pimdm/Packet/PacketPimJoinPrune.py rename to pimdm/packet/PacketPimJoinPrune.py diff --git a/pimdm/Packet/PacketPimJoinPruneMulticastGroup.py b/pimdm/packet/PacketPimJoinPruneMulticastGroup.py similarity index 100% rename from pimdm/Packet/PacketPimJoinPruneMulticastGroup.py rename to pimdm/packet/PacketPimJoinPruneMulticastGroup.py diff --git a/pimdm/Packet/PacketPimStateRefresh.py b/pimdm/packet/PacketPimStateRefresh.py similarity index 100% rename from pimdm/Packet/PacketPimStateRefresh.py rename to pimdm/packet/PacketPimStateRefresh.py diff --git a/pimdm/packet/ReceivedPacket.py b/pimdm/packet/ReceivedPacket.py new file mode 100644 index 0000000..5922302 --- /dev/null +++ b/pimdm/packet/ReceivedPacket.py @@ -0,0 +1,46 @@ +import socket +from .Packet import Packet +from .PacketPimHeader import PacketPimHeader +from .PacketMLDHeader import PacketMLDHeader +from .PacketIGMPHeader import PacketIGMPHeader +from .PacketIpHeader import PacketIpv4Header, PacketIpv6Header +from pimdm.utils import TYPE_CHECKING +if TYPE_CHECKING: + from pimdm.Interface import Interface + + +class ReceivedPacket(Packet): + # choose payload protocol class based on ip protocol number + payload_protocol = {2: PacketIGMPHeader, 103: PacketPimHeader} + + def __init__(self, raw_packet: bytes, interface: 'Interface'): + self.interface = interface + + # Parse packet and fill Packet super class + ip_header = PacketIpv4Header.parse_bytes(raw_packet) + protocol_number = ip_header.proto + + packet_without_ip_hdr = raw_packet[ip_header.hdr_length:] + payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr) + + super().__init__(ip_header=ip_header, payload=payload) + + +class ReceivedPacket_v6(Packet): + # choose payload protocol class based on ip protocol number + payload_protocol_v6 = {58: PacketMLDHeader, 103: PacketPimHeader} + + def __init__(self, raw_packet: bytes, ancdata: list, src_addr: str, next_header: int, interface: 'Interface'): + self.interface = interface + + # Parse packet and fill Packet super class + dst_addr = "::" + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.IPPROTO_IPV6 and cmsg_type == socket.IPV6_PKTINFO: + dst_addr = socket.inet_ntop(socket.AF_INET6, cmsg_data[:16]) + break + + src_addr = src_addr[0].split("%")[0] + ipv6_packet = PacketIpv6Header(ver=6, hop_limit=1, next_header=next_header, ip_src=src_addr, ip_dst=dst_addr) + payload = ReceivedPacket_v6.payload_protocol_v6[next_header].parse_bytes(raw_packet) + super().__init__(ip_header=ipv6_packet, payload=payload) diff --git a/pimdm/packet/__init__.py b/pimdm/packet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pimdm/RWLock/RWLock.py b/pimdm/rwlock/RWLock.py similarity index 100% rename from pimdm/RWLock/RWLock.py rename to pimdm/rwlock/RWLock.py diff --git a/pimdm/rwlock/__init__.py b/pimdm/rwlock/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pimdm/tree/KernelEntry.py b/pimdm/tree/KernelEntry.py index 303f8bd..e6d62ca 100644 --- a/pimdm/tree/KernelEntry.py +++ b/pimdm/tree/KernelEntry.py @@ -1,23 +1,27 @@ +import logging +from time import time +from threading import Lock, RLock + +from pimdm import UnicastRouting +from .metric import AssertMetric from .tree_if_upstream import TreeInterfaceUpstream from .tree_if_downstream import TreeInterfaceDownstream from .tree_interface import TreeInterface -from threading import Lock, RLock -from .metric import AssertMetric -from pimdm import UnicastRouting, Main -from time import time -import logging + class KernelEntry: - TREE_TIMEOUT = 180 KERNEL_LOGGER = logging.getLogger('pim.KernelEntry') - def __init__(self, source_ip: str, group_ip: str): - self.kernel_entry_logger = logging.LoggerAdapter(KernelEntry.KERNEL_LOGGER, {'tree': '(' + source_ip + ',' + group_ip + ')'}) + def __init__(self, source_ip: str, group_ip: str, kernel_entry_interface): + self.kernel_entry_logger = logging.LoggerAdapter(KernelEntry.KERNEL_LOGGER, + {'tree': '(' + source_ip + ',' + group_ip + ')'}) self.kernel_entry_logger.debug('Create KernelEntry') self.source_ip = source_ip self.group_ip = group_ip + self._kernel_entry_interface = kernel_entry_interface + # OBTAIN UNICAST ROUTING INFORMATION################################################### (metric_administrative_distance, metric_cost, rpf_node, root_if, mask) = \ UnicastRouting.get_unicast_info(source_ip) @@ -38,7 +42,7 @@ def __init__(self, source_ip: str, group_ip: str): self.interface_state = {} # type: Dict[int, TreeInterface] with self.CHANGE_STATE_LOCK: - for i in Main.kernel.vif_index_to_name_dic.keys(): + for i in self.get_kernel().vif_index_to_name_dic.keys(): try: if i == self.inbound_interface_index: self.interface_state[i] = TreeInterfaceUpstream(self, i) @@ -55,22 +59,31 @@ def __init__(self, source_ip: str, group_ip: str): print('Tree created') def get_inbound_interface_index(self): + """ + Get VIF of root interface of this tree + """ return self.inbound_interface_index def get_outbound_interfaces_indexes(self): - outbound_indexes = [0] * Main.kernel.MAXVIFS - for (index, state) in self.interface_state.items(): - outbound_indexes[index] = state.is_forwarding() - return outbound_indexes + """ + Get OIL of this tree + """ + return self._kernel_entry_interface.get_outbound_interfaces_indexes(self) ################################################ # Receive (S,G) data packets or control packets ################################################ def recv_data_msg(self, index): + """ + Receive data packet regarding this tree in interface with VIF index + """ print("recv data") self.interface_state[index].recv_data_msg() def recv_assert_msg(self, index, packet): + """ + Receive assert packet regarding this tree in interface with VIF index + """ print("recv assert") pkt_assert = packet.payload.payload metric = pkt_assert.metric @@ -81,28 +94,43 @@ def recv_assert_msg(self, index, packet): self.interface_state[index].recv_assert_msg(received_metric) def recv_prune_msg(self, index, packet): + """ + Receive Prune packet regarding this tree in interface with VIF index + """ print("recv prune msg") holdtime = packet.payload.payload.hold_time upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address self.interface_state[index].recv_prune_msg(upstream_neighbor_address=upstream_neighbor_address, holdtime=holdtime) def recv_join_msg(self, index, packet): + """ + Receive Join packet regarding this tree in interface with VIF index + """ print("recv join msg") upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address self.interface_state[index].recv_join_msg(upstream_neighbor_address) def recv_graft_msg(self, index, packet): + """ + Receive Graft packet regarding this tree in interface with VIF index + """ print("recv graft msg") upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address source_ip = packet.ip_header.ip_src self.interface_state[index].recv_graft_msg(upstream_neighbor_address, source_ip) def recv_graft_ack_msg(self, index, packet): + """ + Receive GraftAck packet regarding this tree in interface with VIF index + """ print("recv graft ack msg") source_ip = packet.ip_header.ip_src self.interface_state[index].recv_graft_ack_msg(source_ip) def recv_state_refresh_msg(self, index, packet): + """ + Receive StateRefresh packet regarding this tree in interface with VIF index + """ print("recv state refresh msg") source_of_state_refresh = packet.ip_header.ip_src @@ -129,11 +157,13 @@ def recv_state_refresh_msg(self, index, packet): self.forward_state_refresh_msg(packet.payload.payload) - ################################################ # Send state refresh msg ################################################ def forward_state_refresh_msg(self, state_refresh_packet): + """ + Forward StateRefresh packet through all interfaces + """ for interface in self.interface_state.values(): interface.send_state_refresh(state_refresh_packet) @@ -142,6 +172,9 @@ def forward_state_refresh_msg(self, state_refresh_packet): # Unicast Changes to RPF ############################################################### def network_update(self): + """ + Unicast routing table suffered an update and this tree might be affected by it + """ # TODO TALVEZ OUTRO LOCK PARA BLOQUEAR ENTRADA DE PACOTES with self.CHANGE_STATE_LOCK: @@ -184,24 +217,34 @@ def network_update(self): self.rpf_node = rpf_node self.interface_state[self.inbound_interface_index].change_on_unicast_routing() - - # check if add/removal of neighbors from interface afects olist and forward/prune state of interface def change_at_number_of_neighbors(self): + """ + Check if modification of number of neighbors causes changes to OIL and interest of interface + """ with self.CHANGE_STATE_LOCK: self.change() self.evaluate_olist_change() def new_or_reset_neighbor(self, if_index, neighbor_ip): + """ + An interface identified by if_index has a new neighbor + """ # todo maybe lock de interfaces self.interface_state[if_index].new_or_reset_neighbor(neighbor_ip) def is_olist_null(self): + """ + Check if olist is null + """ for interface in self.interface_state.values(): if interface.is_forwarding(): return False return True def evaluate_olist_change(self): + """ + React to changes on the olist + """ with self._lock_test2: is_olist_null = self.is_olist_null() @@ -214,33 +257,74 @@ def evaluate_olist_change(self): self._was_olist_null = is_olist_null def get_source(self): + """ + Get source IP of multicast source + """ return self.source_ip def get_group(self): + """ + Get group IP of multicast tree + """ return self.group_ip def change(self): + """ + Trigger an update on the multicast routing table + """ with self._multicast_change: - Main.kernel.set_multicast_route(self) + self.get_kernel().set_multicast_route(self) def delete(self): + """ + Remove kernel entry + """ with self._multicast_change: for state in self.interface_state.values(): state.delete() - Main.kernel.remove_multicast_route(self) + self.get_kernel().remove_multicast_route(self) + + def get_interface_name(self, interface_id): + """ + Get interface name of interface identified by interface_id + """ + return self._kernel_entry_interface.get_interface_name(interface_id) + + def get_interface(self, interface_id): + """ + Get PIM interface + """ + return self._kernel_entry_interface.get_interface(self, interface_id) + + def get_membership_interface(self, interface_id): + """ + Get IGMP/MLD interface + """ + return self._kernel_entry_interface.get_membership_interface(self, interface_id) + def get_kernel(self): + """ + Get kernel + """ + return self._kernel_entry_interface.get_kernel() ###################################### # Interface change ####################################### def new_interface(self, index): + """ + React to a new interface that was added and in which a tree was already built + """ with self.CHANGE_STATE_LOCK: self.interface_state[index] = TreeInterfaceDownstream(self, index) self.change() self.evaluate_olist_change() def remove_interface(self, index): + """ + React to removal of an interface of a tree that was already built + """ with self.CHANGE_STATE_LOCK: #check if removed interface is root interface if self.inbound_interface_index == index: diff --git a/pimdm/tree/KernelEntryInterface.py b/pimdm/tree/KernelEntryInterface.py new file mode 100644 index 0000000..c7b22d5 --- /dev/null +++ b/pimdm/tree/KernelEntryInterface.py @@ -0,0 +1,128 @@ +from pimdm import Main +from abc import abstractmethod, ABCMeta + + +class KernelEntryInterface(metaclass=ABCMeta): + @staticmethod + @abstractmethod + def get_outbound_interfaces_indexes(kernel_tree): + """ + Get OIL of this tree + """ + pass + + @staticmethod + @abstractmethod + def get_interface_name(interface_id): + """ + Get name of interface from vif id + """ + pass + + @staticmethod + @abstractmethod + def get_interface(kernel_tree, interface_id): + """ + Get PIM interface from interface id + """ + pass + + @staticmethod + @abstractmethod + def get_membership_interface(kernel_tree, interface_id): + """ + Get IGMP/MLD interface from interface id + """ + pass + + @staticmethod + @abstractmethod + def get_kernel(): + """ + Get kernel + """ + pass + + +class KernelEntry4Interface(KernelEntryInterface): + @staticmethod + def get_outbound_interfaces_indexes(kernel_tree): + """ + Get OIL of this tree + """ + outbound_indexes = [0] * Main.kernel.MAXVIFS + for (index, state) in kernel_tree.interface_state.items(): + outbound_indexes[index] = state.is_forwarding() + return outbound_indexes + + @staticmethod + def get_interface_name(interface_id): + """ + Get name of interface from vif id + """ + return Main.kernel.vif_index_to_name_dic[interface_id] + + @staticmethod + def get_interface(kernel_tree, interface_id): + """ + Get PIM interface from interface id + """ + interface_name = kernel_tree.get_interface_name(interface_id) + return Main.interfaces.get(interface_name, None) + + @staticmethod + def get_membership_interface(kernel_tree, interface_id): + """ + Get IGMP interface from interface id + """ + interface_name = kernel_tree.get_interface_name(interface_id) + return Main.igmp_interfaces.get(interface_name, None) # type: InterfaceIGMP + + @staticmethod + def get_kernel(): + """ + Get kernel + """ + return Main.kernel + + +class KernelEntry6Interface(KernelEntryInterface): + @staticmethod + def get_outbound_interfaces_indexes(kernel_tree): + """ + Get OIL of this tree + """ + outbound_indexes = [0] * 8 + for (index, state) in kernel_tree.interface_state.items(): + outbound_indexes[index // 32] |= state.is_forwarding() << (index % 32) + return outbound_indexes + + @staticmethod + def get_interface_name(interface_id): + """ + Get name of interface from vif id + """ + return Main.kernel_v6.vif_index_to_name_dic[interface_id] + + @staticmethod + def get_interface(kernel_tree, interface_id): + """ + Get PIM interface from interface id + """ + interface_name = kernel_tree.get_interface_name(interface_id) + return Main.interfaces_v6.get(interface_name, None) + + @staticmethod + def get_membership_interface(kernel_tree, interface_id): + """ + Get MLD interface from interface id + """ + interface_name = kernel_tree.get_interface_name(interface_id) + return Main.mld_interfaces.get(interface_name, None) # type: InterfaceMLD + + @staticmethod + def get_kernel(): + """ + Get kernel + """ + return Main.kernel_v6 diff --git a/pimdm/tree/assert_.py b/pimdm/tree/assert_state.py similarity index 99% rename from pimdm/tree/assert_.py rename to pimdm/tree/assert_state.py index 8b7dd3d..27408ae 100644 --- a/pimdm/tree/assert_.py +++ b/pimdm/tree/assert_state.py @@ -112,12 +112,10 @@ def receivedPruneOrJoinOrGraft(interface: "TreeInterfaceDownstream"): """ raise NotImplementedError() - def _sendAssert_setAT(interface: "TreeInterfaceDownstream"): interface.set_assert_timer(pim_globals.ASSERT_TIME) interface.send_assert() - # Override def __str__(self) -> str: return "AssertSM:" + self.__class__.__name__ @@ -289,7 +287,6 @@ def __str__(self) -> str: return "Winner" - class LoserState(AssertStateABC): ''' I am Assert Loser (L) @@ -370,6 +367,7 @@ def _to_NoInfo(interface: "TreeInterfaceDownstream"): def __str__(self) -> str: return "Loser" + class AssertState(): NoInfo = NoInfoState() Winner = WinnerState() diff --git a/pimdm/tree/DataPacketsSocket.py b/pimdm/tree/data_packets_socket.py similarity index 50% rename from pimdm/tree/DataPacketsSocket.py rename to pimdm/tree/data_packets_socket.py index 3cc02f0..ff79e16 100644 --- a/pimdm/tree/DataPacketsSocket.py +++ b/pimdm/tree/data_packets_socket.py @@ -1,15 +1,29 @@ -import subprocess import struct import socket +import ipaddress +import subprocess from ctypes import create_string_buffer, addressof SO_ATTACH_FILTER = 26 -ETH_P_IP = 0x0800 # Internet Protocol packet +ETH_P_IP = 0x0800 # Internet Protocol packet +ETH_P_IPV6 = 0x86DD # IPv6 over bluebook + SO_RCVBUFFORCE = 33 def get_s_g_bpf_filter_code(source, group, interface_name): - #cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group) - cmd = "tcpdump -ddd \"(ip proto not 2) and host %s and dst %s\"" % (source, group) + ip_source_version = ipaddress.ip_address(source).version + ip_group_version = ipaddress.ip_address(source).version + if ip_source_version == ip_group_version == 4: + # cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group) + cmd = "tcpdump -ddd \"(ip proto not 2) and host %s and dst %s\"" % (source, group) + protocol = ETH_P_IP + elif ip_source_version == ip_group_version == 6: + # TODO: allow ICMPv6 echo request/echo response to be considered multicast packets + cmd = "tcpdump -ddd \"(ip6 proto not 58) and host %s and dst %s\"" % (source, group) + protocol = ETH_P_IPV6 + else: + raise Exception("Unknown IP family") + result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) bpf_filter = b'' @@ -28,10 +42,10 @@ def get_s_g_bpf_filter_code(source, group, interface_name): # Create listening socket with filters - s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, ETH_P_IP) + s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, protocol) s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog) # todo pequeno ajuste (tamanho de buffer pequeno para o caso de trafego em rajadas): #s.setsockopt(socket.SOL_SOCKET, SO_RCVBUFFORCE, 1) - s.bind((interface_name, ETH_P_IP)) + s.bind((interface_name, protocol)) return s diff --git a/pimdm/tree/globals.py b/pimdm/tree/globals.py index 79ae836..4ee9f96 100644 --- a/pimdm/tree/globals.py +++ b/pimdm/tree/globals.py @@ -7,4 +7,8 @@ SOURCE_LIFETIME = 210 T_LIMIT = 210 +HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF +HELLO_HOLD_TIME = 160 +HELLO_HOLD_TIME_TIMEOUT = 0 + ASSERT_CANCEL_METRIC = 0xFFFFFFFF \ No newline at end of file diff --git a/pimdm/tree/originator.py b/pimdm/tree/originator.py index b6c487f..a0af4bb 100644 --- a/pimdm/tree/originator.py +++ b/pimdm/tree/originator.py @@ -1,20 +1,24 @@ -from abc import ABCMeta, abstractstaticmethod +from abc import ABCMeta, abstractmethod class OriginatorStateABC(metaclass=ABCMeta): - @abstractstaticmethod + @staticmethod + @abstractmethod def recvDataMsgFromSource(tree): pass - @abstractstaticmethod + @staticmethod + @abstractmethod def SRTexpires(tree): pass - @abstractstaticmethod + @staticmethod + @abstractmethod def SATexpires(tree): pass - @abstractstaticmethod + @staticmethod + @abstractmethod def SourceNotConnected(tree): pass diff --git a/pimdm/tree/tree_if_downstream.py b/pimdm/tree/tree_if_downstream.py index 2798ddb..aae9709 100644 --- a/pimdm/tree/tree_if_downstream.py +++ b/pimdm/tree/tree_if_downstream.py @@ -1,19 +1,13 @@ -''' -Created on Jul 16, 2015 - -@author: alex -''' from threading import Timer -from pimdm.CustomTimer.RemainingTimer import RemainingTimer -from .assert_ import AssertState +from pimdm.custom_timer.RemainingTimer import RemainingTimer +from .assert_state import AssertState from .downstream_prune import DownstreamState, DownstreamStateABS from .tree_interface import TreeInterface -from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh -from pimdm.Packet.Packet import Packet -from pimdm.Packet.PacketPimHeader import PacketPimHeader +from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh +from pimdm.packet.Packet import Packet +from pimdm.packet.PacketPimHeader import PacketPimHeader import traceback import logging -from .. import Main class TreeInterfaceDownstream(TreeInterface): @@ -22,7 +16,7 @@ class TreeInterfaceDownstream(TreeInterface): def __init__(self, kernel_entry, interface_id): extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy() extra_dict_logger['vif'] = interface_id - extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id] + extra_dict_logger['interfacename'] = kernel_entry.get_interface_name(interface_id) logger = logging.LoggerAdapter(TreeInterfaceDownstream.LOGGER, extra_dict_logger) TreeInterface.__init__(self, kernel_entry, interface_id, logger) self.logger.debug('Created DownstreamInterface') diff --git a/pimdm/tree/tree_if_upstream.py b/pimdm/tree/tree_if_upstream.py index b1e3e98..90a31c5 100644 --- a/pimdm/tree/tree_if_upstream.py +++ b/pimdm/tree/tree_if_upstream.py @@ -1,22 +1,16 @@ -''' -Created on Jul 16, 2015 - -@author: alex -''' from .tree_interface import TreeInterface from .upstream_prune import UpstreamState from threading import Timer -from pimdm.CustomTimer.RemainingTimer import RemainingTimer +from pimdm.custom_timer.RemainingTimer import RemainingTimer from .globals import * import random from .metric import AssertMetric from .originator import OriginatorState, OriginatorStateABC -from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh +from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh import traceback -from . import DataPacketsSocket +from . import data_packets_socket import threading import logging -from .. import Main class TreeInterfaceUpstream(TreeInterface): @@ -25,7 +19,7 @@ class TreeInterfaceUpstream(TreeInterface): def __init__(self, kernel_entry, interface_id): extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy() extra_dict_logger['vif'] = interface_id - extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id] + extra_dict_logger['interfacename'] = kernel_entry.get_interface_name(interface_id) logger = logging.LoggerAdapter(TreeInterfaceUpstream.LOGGER, extra_dict_logger) TreeInterface.__init__(self, kernel_entry, interface_id, logger) @@ -47,15 +41,16 @@ def __init__(self, kernel_entry, interface_id): if self.is_S_directly_conn(): self._graft_prune_state.sourceIsNowDirectConnect(self) - if self.get_interface().is_state_refresh_enabled(): + interface = self.get_interface() + if interface is not None and interface.is_state_refresh_enabled(): self._originator_state.recvDataMsgFromSource(self) # TODO TESTE SOCKET RECV DATA PCKTS self.socket_is_enabled = True - (s,g) = self.get_tree_id() - interface_name = self.get_interface().interface_name - self.socket_pkt = DataPacketsSocket.get_s_g_bpf_filter_code(s, g, interface_name) + (s, g) = self.get_tree_id() + interface_name = self.get_interface_name() + self.socket_pkt = data_packets_socket.get_s_g_bpf_filter_code(s, g, interface_name) # run receive method in background receive_thread = threading.Thread(target=self.socket_recv) @@ -182,8 +177,10 @@ def source_active_timeout(self): def recv_data_msg(self): if not self.is_prune_limit_timer_running() and not self.is_S_directly_conn() and self.is_olist_null(): self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self) - elif self.is_S_directly_conn() and self.get_interface().is_state_refresh_enabled(): - self._originator_state.recvDataMsgFromSource(self) + elif self.is_S_directly_conn(): + interface = self.get_interface() + if interface is not None and interface.is_state_refresh_enabled(): + self._originator_state.recvDataMsgFromSource(self) def recv_join_msg(self, upstream_neighbor_address): diff --git a/pimdm/tree/tree_interface.py b/pimdm/tree/tree_interface.py index 1afa755..32adf54 100644 --- a/pimdm/tree/tree_interface.py +++ b/pimdm/tree/tree_interface.py @@ -1,31 +1,26 @@ -''' -Created on Jul 16, 2015 - -@author: alex -''' from abc import ABCMeta, abstractmethod -from .. import Main from threading import RLock import traceback from .downstream_prune import DownstreamState -from .assert_ import AssertState, AssertStateABC +from .assert_state import AssertState, AssertStateABC -from pimdm.Packet.PacketPimGraft import PacketPimGraft -from pimdm.Packet.PacketPimGraftAck import PacketPimGraftAck -from pimdm.Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup -from pimdm.Packet.PacketPimHeader import PacketPimHeader -from pimdm.Packet.Packet import Packet +from pimdm.packet.PacketPimGraft import PacketPimGraft +from pimdm.packet.PacketPimGraftAck import PacketPimGraftAck +from pimdm.packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup +from pimdm.packet.PacketPimHeader import PacketPimHeader +from pimdm.packet.Packet import Packet -from pimdm.Packet.PacketPimJoinPrune import PacketPimJoinPrune -from pimdm.Packet.PacketPimAssert import PacketPimAssert -from pimdm.Packet.PacketPimStateRefresh import PacketPimStateRefresh +from pimdm.packet.PacketPimJoinPrune import PacketPimJoinPrune +from pimdm.packet.PacketPimAssert import PacketPimAssert +from pimdm.packet.PacketPimStateRefresh import PacketPimStateRefresh from .metric import AssertMetric from threading import Timer from .local_membership import LocalMembership -from .globals import * +from .globals import T_LIMIT import logging + class TreeInterface(metaclass=ABCMeta): def __init__(self, kernel_entry, interface_id, logger: logging.LoggerAdapter): self._kernel_entry = kernel_entry @@ -36,9 +31,8 @@ def __init__(self, kernel_entry, interface_id, logger: logging.LoggerAdapter): # Local Membership State try: - interface_name = Main.kernel.vif_index_to_name_dic[interface_id] - igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP - group_state = igmp_interface.interface_state.get_group_state(kernel_entry.group_ip) + membership_interface = self.get_membership_interface() + group_state = membership_interface.interface_state.get_group_state(kernel_entry.group_ip) #self._igmp_has_members = group_state.add_multicast_routing_entry(self) igmp_has_members = group_state.add_multicast_routing_entry(self) self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo @@ -60,8 +54,7 @@ def __init__(self, kernel_entry, interface_id, logger: logging.LoggerAdapter): # Received prune hold time self._received_prune_holdtime = None - self._igmp_lock = RLock() - + self._membership_lock = RLock() ############################################ # Set ASSERT State @@ -90,7 +83,6 @@ def set_assert_winner_metric(self, new_assert_metric: AssertMetric): finally: self._assert_winner_metric = new_assert_metric - ############################################ # ASSERT Timer ############################################ @@ -106,7 +98,6 @@ def clear_assert_timer(self): def assert_timeout(self): self._assert_state.assertTimerExpires(self) - ########################################### # Recv packets ########################################### @@ -145,7 +136,6 @@ def recv_graft_ack_msg(self, source_ip_of_graft_ack): def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator): self.recv_assert_msg(received_metric) - ###################################### # Send messages ###################################### @@ -163,7 +153,6 @@ def send_graft(self): traceback.print_exc() return - def send_graft_ack(self, ip_sender): print("send graft ack") try: @@ -177,7 +166,6 @@ def send_graft_ack(self, ip_sender): traceback.print_exc() return - def send_prune(self, holdtime=None): if holdtime is None: holdtime = T_LIMIT @@ -195,7 +183,6 @@ def send_prune(self, holdtime=None): traceback.print_exc() return - def send_pruneecho(self): holdtime = T_LIMIT try: @@ -210,7 +197,6 @@ def send_pruneecho(self): traceback.print_exc() return - def send_join(self): print("send join") @@ -225,7 +211,6 @@ def send_join(self): traceback.print_exc() return - def send_assert(self): print("send assert") @@ -240,7 +225,6 @@ def send_assert(self): traceback.print_exc() return - def send_assert_cancel(self): print("send assert cancel") @@ -254,7 +238,6 @@ def send_assert_cancel(self): traceback.print_exc() return - def send_state_refresh(self, state_refresh_msg_received: PacketPimStateRefresh): pass @@ -282,9 +265,8 @@ def delete(self, change_type_interface=False): (s, g) = self.get_tree_id() # unsubscribe igmp information try: - interface_name = Main.kernel.vif_index_to_name_dic[self._interface_id] - igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP - group_state = igmp_interface.interface_state.get_group_state(g) + membership_interface = self.get_membership_interface() + group_state = membership_interface.interface_state.get_group_state(g) group_state.remove_multicast_routing_entry(self) except: pass @@ -306,29 +288,29 @@ def is_olist_null(self): def evaluate_ingroup(self): self._kernel_entry.evaluate_olist_change() - ############################################################# # Local Membership (IGMP) ############################################################ - def notify_igmp(self, has_members: bool): + def notify_membership(self, has_members: bool): with self.get_state_lock(): - with self._igmp_lock: + with self._membership_lock: if has_members != self._local_membership_state.has_members(): self._local_membership_state = LocalMembership.Include if has_members else LocalMembership.NoInfo self.change_tree() self.evaluate_ingroup() - def igmp_has_members(self): - with self._igmp_lock: + with self._membership_lock: return self._local_membership_state.has_members() + def get_interface_name(self): + return self._kernel_entry.get_interface_name(self._interface_id) + def get_interface(self): - kernel = Main.kernel - interface_name = kernel.vif_index_to_name_dic[self._interface_id] - interface = Main.interfaces[interface_name] - return interface + return self._kernel_entry.get_interface(self._interface_id) + def get_membership_interface(self): + return self._kernel_entry.get_membership_interface(self._interface_id) def get_ip(self): ip = self.get_interface().get_ip() @@ -353,9 +335,6 @@ def get_state_lock(self): def is_downstream(self): raise NotImplementedError() - - - # obtain ip of RPF'(S) def get_neighbor_RPF(self): ''' @@ -375,8 +354,6 @@ def set_receceived_prune_holdtime(self, holdtime): def get_received_prune_holdtime(self): return self._received_prune_holdtime - - ################################################### # ASSERT ################################################### diff --git a/pimdm/utils.py b/pimdm/utils.py index 9ac0846..a611ec3 100644 --- a/pimdm/utils.py +++ b/pimdm/utils.py @@ -1,29 +1,4 @@ import array -''' -import struct -if struct.pack("H",1) == "\x00\x01": # big endian - def checksum(pkt): - if len(pkt) % 2 == 1: - pkt += "\0" - s = sum(array.array("H", pkt)) - s = (s >> 16) + (s & 0xffff) - s += s >> 16 - s = ~s - return s & 0xffff -else: - def checksum(pkt): - if len(pkt) % 2 == 1: - pkt += "\0" - s = sum(array.array("H", pkt)) - s = (s >> 16) + (s & 0xffff) - s += s >> 16 - s = ~s - return (((s>>8)&0xff)|s<<8) & 0xffff -''' - -HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF -HELLO_HOLD_TIME = 160 -HELLO_HOLD_TIME_TIMEOUT = 0 def checksum(pkt: bytes) -> bytes: @@ -36,35 +11,6 @@ def checksum(pkt: bytes) -> bytes: return (((s >> 8) & 0xff) | s << 8) & 0xffff -import ctypes -import ctypes.util - -libc = ctypes.CDLL(ctypes.util.find_library('c')) - - -def if_nametoindex(name): - if not isinstance(name, str): - raise TypeError('name must be a string.') - ret = libc.if_nametoindex(name) - if not ret: - raise RuntimeError("Invalid Name") - return ret - - -def if_indextoname(index): - if not isinstance(index, int): - raise TypeError('index must be an int.') - libc.if_indextoname.argtypes = [ctypes.c_uint32, ctypes.c_char_p] - libc.if_indextoname.restype = ctypes.c_char_p - - ifname = ctypes.create_string_buffer(32) - ifname = libc.if_indextoname(index, ifname) - if not ifname: - raise RuntimeError ("Inavlid Index") - return ifname.decode("utf-8") - - - # obtain TYPE_CHECKING (for type hinting) try: from typing import TYPE_CHECKING diff --git a/setup.py b/setup.py index 8724c5c..62133c4 100644 --- a/setup.py +++ b/setup.py @@ -12,8 +12,8 @@ description="PIM-DM protocol", long_description=open("README.md", "r").read(), long_description_content_type="text/markdown", - keywords="PIM-DM Multicast Routing Protocol Dense-Mode Router RFC3973", - version="1.0.4.2", + keywords="PIM-DM Multicast Routing Protocol Dense-Mode Router RFC3973 IPv4 IPv6", + version="1.1", url="http://github.com/pedrofran12/pim_dm", author="Pedro Oliveira", author_email="pedro.francisco.oliveira@tecnico.ulisboa.pt", @@ -38,7 +38,6 @@ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.2", "Programming Language :: Python :: 3.3", "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", @@ -46,5 +45,5 @@ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", ], - python_requires=">=3.2", + python_requires=">=3.3", )