diff --git a/pyroute2/iproute/linux.py b/pyroute2/iproute/linux.py index e9b08c36c..f85e40bc1 100644 --- a/pyroute2/iproute/linux.py +++ b/pyroute2/iproute/linux.py @@ -4,9 +4,9 @@ import time import warnings from functools import partial -from socket import AF_INET, AF_UNSPEC +from socket import AF_INET, AF_INET6, AF_UNSPEC -from pyroute2.common import basestring +from pyroute2.common import AF_MPLS, basestring from pyroute2.config import AF_BRIDGE from pyroute2.netlink import NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST, NLMSG_ERROR from pyroute2.netlink.core import SyncAPI @@ -53,7 +53,13 @@ RTMGRP_DEFAULTS, RTMGRP_IPV4_IFADDR, RTMGRP_IPV4_ROUTE, + RTMGRP_IPV4_RULE, + RTMGRP_IPV6_IFADDR, + RTMGRP_IPV6_ROUTE, + RTMGRP_IPV6_RULE, RTMGRP_LINK, + RTMGRP_MPLS_ROUTE, + RTMGRP_NEIGH, ndmsg, ) from pyroute2.netlink.rtnl.fibmsg import fibmsg @@ -115,18 +121,19 @@ def get_dump_filter(mode, command, query): if command != 'dump': return RequestProcessor(), query if 'match' in query: - return query.pop('match'), query - else: - new_query = {} - if 'family' in query: - new_query['family'] = query.pop('family') - dump_filter = RequestProcessor(context=query, prime=query) - for rf in query.pop( - 'dump_filter', get_default_request_filters(mode, command) - ): - dump_filter.add_filter(rf) - dump_filter.finalize() - return dump_filter, new_query + query = query['match'] + if callable(query): + return query, {} + new_query = {} + if 'family' in query: + new_query['family'] = query.pop('family') + dump_filter = RequestProcessor(context=query, prime=query) + for rf in query.pop( + 'dump_filter', get_default_request_filters(mode, command) + ): + dump_filter.add_filter(rf) + dump_filter.finalize() + return dump_filter, new_query def get_arguments_processor(mode, command, query, parameters=None): @@ -189,26 +196,6 @@ class RTNL_API: ipr.link('set', index=dev, state='up') ''' - def match_one_message(self, dump_filter, msg): - if hasattr(dump_filter, '__call__'): - return dump_filter(msg) - elif isinstance(dump_filter, dict): - matches = [] - for key in dump_filter: - # get the attribute - if isinstance(key, str): - nkey = (key,) - elif isinstance(key, tuple): - nkey = key - else: - continue - value = msg.get_nested(*nkey) - if value is not None and callable(dump_filter[key]): - matches.append(dump_filter[key](value)) - else: - matches.append(dump_filter[key] == value) - return all(matches) - def filter_messages(self, dump_filter, msgs): ''' Filter messages using `dump_filter`. The filter might be a @@ -288,36 +275,31 @@ async def dump(self, groups=None): # and the code may be run on BSD systems as well, though # BSD systems have only subset of the API # - # if self.uname[0] == 'OpenBSD': - # groups_map = { - # 1: [ - # self.get_links, - # self.get_addr, - # self.get_neighbours, - # self.get_routes, - # ] - # } - # else: - # groups_map = { - # RTMGRP_LINK: [ - # self.get_links, - # self.get_vlans, - # partial(self.fdb, 'dump'), - # ], - # RTMGRP_IPV4_IFADDR: [partial(self.get_addr, family=AF_INET)], - # RTMGRP_IPV6_IFADDR: [partial(self.get_addr, family=AF_INET6)], - # RTMGRP_NEIGH: [self.get_neighbours], - # RTMGRP_IPV4_ROUTE: [partial(self.get_routes, family=AF_INET)], - # RTMGRP_IPV6_ROUTE: [partial(self.get_routes, family=AF_INET6)], - # RTMGRP_MPLS_ROUTE: [partial(self.get_routes, family=AF_MPLS)], - # RTMGRP_IPV4_RULE: [partial(self.get_rules, family=AF_INET)], - # RTMGRP_IPV6_RULE: [partial(self.get_rules, family=AF_INET6)], - # } - groups_map = { - RTMGRP_LINK: [partial(self.link, 'dump')], - RTMGRP_IPV4_IFADDR: [partial(self.addr, 'dump', family=AF_INET)], - RTMGRP_IPV4_ROUTE: [partial(self.route, 'dump', family=AF_INET)], - } + if self.uname[0] == 'OpenBSD': + groups_map = { + 1: [ + self.get_links, + self.get_addr, + self.get_neighbours, + self.get_routes, + ] + } + else: + groups_map = { + RTMGRP_LINK: [ + self.get_links, + self.get_vlans, + partial(self.fdb, 'dump'), + ], + RTMGRP_IPV4_IFADDR: [partial(self.get_addr, family=AF_INET)], + RTMGRP_IPV6_IFADDR: [partial(self.get_addr, family=AF_INET6)], + RTMGRP_NEIGH: [self.get_neighbours], + RTMGRP_IPV4_ROUTE: [partial(self.get_routes, family=AF_INET)], + RTMGRP_IPV6_ROUTE: [partial(self.get_routes, family=AF_INET6)], + RTMGRP_MPLS_ROUTE: [partial(self.get_routes, family=AF_MPLS)], + RTMGRP_IPV4_RULE: [partial(self.get_rules, family=AF_INET)], + RTMGRP_IPV6_RULE: [partial(self.get_rules, family=AF_INET6)], + } for group, methods in groups_map.items(): if group & (groups if groups is not None else self.groups): for method in methods: @@ -430,7 +412,7 @@ async def get_qdiscs(self, index=None): ''' return await self.tc('dump') - def get_filters(self, index=0, handle=0, parent=0): + async def get_filters(self, index=0, handle=0, parent=0): ''' Get filters for specified interface, handle and parent. ''' @@ -439,16 +421,20 @@ def get_filters(self, index=0, handle=0, parent=0): msg['index'] = index msg['handle'] = transform_handle(handle) msg['parent'] = transform_handle(parent) - return tuple(self.nlm_request(msg, RTM_GETTFILTER)) + request = NetlinkRequest(self, msg, msg_type=RTM_GETTFILTER) + await request.send() + return request.response() - def get_classes(self, index=0): + async def get_classes(self, index=0): ''' Get classes for specified interface. ''' msg = tcmsg() msg['family'] = AF_UNSPEC msg['index'] = index - return tuple(self.nlm_request(msg, RTM_GETTCLASS)) + request = NetlinkRequest(self, msg, msg_type=RTM_GETTCLASS) + await request.send() + return request.response() async def get_vlans(self, **kwarg): ''' @@ -820,7 +806,9 @@ async def get_default_routes(self, family=AF_UNSPEC, table=DEFAULT_TABLE): ''' msg = rtmsg() msg['family'] = family - dump_filter, _ = get_dump_filter('route', 'dump', {'table': table} if table is not None else {}) + dump_filter, _ = get_dump_filter( + 'route', 'dump', {'table': table} if table is not None else {} + ) request = NetlinkRequest( self, msg, @@ -2461,6 +2449,8 @@ def __getattr__(self, name): ] async_dump_methods = [ 'get_qdiscs', + 'get_filters', + 'get_classes', 'get_links', 'get_addr', 'get_neighbours', @@ -2479,7 +2469,11 @@ async def collect_dump(): async def collect_op(): return await symbol(*argv, **kwarg) - if len(argv) > 0 and argv[0].startswith('dump'): + if ( + len(argv) > 0 + and isinstance(argv[0], str) + and argv[0].startswith('dump') + ): task = collect_dump else: task = collect_op diff --git a/pyroute2/ndb/objects/__init__.py b/pyroute2/ndb/objects/__init__.py index aafc4a2e3..87e20d6ed 100644 --- a/pyroute2/ndb/objects/__init__.py +++ b/pyroute2/ndb/objects/__init__.py @@ -253,8 +253,8 @@ def dump(cls, view): ) @classmethod - def spec_normalize(cls, processed, spec): - return processed + def spec_normalize(cls, spec): + return spec @staticmethod def key_load_context(key, context): @@ -350,13 +350,15 @@ def __init__( def new_spec(cls, spec, context=None, localhost=None): if isinstance(spec, Record): spec = spec._as_dict() - rp = RequestProcessor(context=spec, prime=spec) + spec = cls.spec_normalize(spec) + rp = RequestProcessor(context=spec) rp.add_filter(cls.field_filter()) + rp.update(spec) if isinstance(context, dict): rp.update(context) if 'target' not in rp and localhost is not None: rp['target'] = localhost - return cls.spec_normalize(rp, spec) + return rp @staticmethod def resolve(view, spec, fields, policy=RSLV_IGNORE): diff --git a/pyroute2/ndb/objects/address.py b/pyroute2/ndb/objects/address.py index 29cc0821b..421e2e502 100644 --- a/pyroute2/ndb/objects/address.py +++ b/pyroute2/ndb/objects/address.py @@ -319,7 +319,7 @@ def compare_record(left, right): ) @classmethod - def spec_normalize(cls, processed, spec): + def spec_normalize(cls, spec): ''' Address key normalization:: @@ -328,8 +328,8 @@ def spec_normalize(cls, processed, spec): "prefixlen": 24} ''' if isinstance(spec, str): - processed['address'] = spec - return processed + return {'address': spec} + return spec def key_repr(self): return '%s/%s %s/%s' % ( diff --git a/pyroute2/ndb/objects/interface.py b/pyroute2/ndb/objects/interface.py index 8d73c9a98..f67e337f5 100644 --- a/pyroute2/ndb/objects/interface.py +++ b/pyroute2/ndb/objects/interface.py @@ -879,7 +879,7 @@ def __setitem__(self, key, value): super(Interface, self).__setitem__(key, value) @classmethod - def spec_normalize(cls, processed, spec): + def spec_normalize(cls, spec): ''' Interface key normalization:: @@ -889,10 +889,10 @@ def spec_normalize(cls, processed, spec): ''' if isinstance(spec, basestring): - processed['ifname'] = spec + return {'ifname': spec} elif isinstance(spec, int): - processed['index'] = spec - return processed + return {'index': spec} + return spec def complete_key(self, key): if isinstance(key, dict): diff --git a/pyroute2/ndb/objects/netns.py b/pyroute2/ndb/objects/netns.py index 8675fd453..3022b5e70 100644 --- a/pyroute2/ndb/objects/netns.py +++ b/pyroute2/ndb/objects/netns.py @@ -45,17 +45,17 @@ def __init__(self, *argv, **kwarg): super(NetNS, self).__init__(*argv, **kwarg) @classmethod - def spec_normalize(cls, processed, spec): + def spec_normalize(cls, spec): if isinstance(spec, basestring): - processed['path'] = spec - path = netns._get_netnspath(processed['path']) + spec = {'path': spec} + path = netns._get_netnspath(spec['path']) # on Python3 _get_netnspath() returns bytes, not str, so # we have to decode it here in order to avoid issues with # cache keys and DB inserts if hasattr(path, 'decode'): path = path.decode('utf-8') - processed['path'] = path - return processed + spec['path'] = path + return spec def __setitem__(self, key, value): if self.state == 'system': diff --git a/pyroute2/ndb/objects/route.py b/pyroute2/ndb/objects/route.py index a2ae3e3a5..132cfbc16 100644 --- a/pyroute2/ndb/objects/route.py +++ b/pyroute2/ndb/objects/route.py @@ -556,10 +556,10 @@ def dump(cls, view): yield record @classmethod - def spec_normalize(cls, processed, spec): + def spec_normalize(cls, spec): if isinstance(spec, basestring): - processed['dst'] = spec - return processed + return {'dst': spec} + return spec @classmethod def compare_record(self, left, right): diff --git a/pyroute2/netlink/core.py b/pyroute2/netlink/core.py index 81e7557b9..cd1a48788 100644 --- a/pyroute2/netlink/core.py +++ b/pyroute2/netlink/core.py @@ -590,7 +590,14 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __getattr__(self, key): - if key in ('pid', 'send', 'recv', 'sendto'): + if key in ( + 'pid', + 'send', + 'recv', + 'sendto', + 'register_policy', + 'unregister_policy', + ): return getattr(self.asyncore, key) raise AttributeError(key) diff --git a/pyroute2/netlink/nlsocket.py b/pyroute2/netlink/nlsocket.py index e09dafd84..cbf64f95e 100644 --- a/pyroute2/netlink/nlsocket.py +++ b/pyroute2/netlink/nlsocket.py @@ -475,13 +475,9 @@ def match_one_message(self, msg): matches = [] for key in self.dump_filter: # get the attribute - if isinstance(key, str): - nkey = (key,) - elif isinstance(key, tuple): - nkey = key - else: + if not isinstance(key, (str, tuple)): continue - value = msg.get(*nkey) + value = msg.get(key) if value is not None and callable(self.dump_filter[key]): matches.append(self.dump_filter[key](value)) else: diff --git a/pyroute2/requests/main.py b/pyroute2/requests/main.py index 65ada9795..f240027f2 100644 --- a/pyroute2/requests/main.py +++ b/pyroute2/requests/main.py @@ -17,6 +17,7 @@ def __init__(self, context=None, prime=None): self.reset_filters() self.reset_mark() self.parameters = {} + prime = {} if prime is None else prime self.context = ( context if isinstance(context, (dict, weakref.ProxyType)) else {} ) diff --git a/pyroute2/requests/tc.py b/pyroute2/requests/tc.py index 78ffa26c9..49ea54819 100644 --- a/pyroute2/requests/tc.py +++ b/pyroute2/requests/tc.py @@ -45,9 +45,14 @@ def finalize(self, context): 'parent', getattr(plugin, 'parent', 0) ) if set(context.keys()) > set(('kind', 'index', 'handle')): + get_parameters = None if self.command[-5:] == 'class': - context['options'] = plugin.get_class_parameters(context) + get_parameters = getattr( + plugin, 'get_class_parameters', None + ) else: - context['options'] = plugin.get_parameters(context) + get_parameters = getattr(plugin, 'get_parameters', None) + if get_parameters is not None: + context['options'] = get_parameters(dict(context)) if hasattr(plugin, 'fix_request'): plugin.fix_request(context) diff --git a/tests/test_linux/pr2test/tools.py b/tests/test_linux/pr2test/tools.py index 53a5b7369..148fdcee9 100644 --- a/tests/test_linux/pr2test/tools.py +++ b/tests/test_linux/pr2test/tools.py @@ -1,5 +1,4 @@ -from pyroute2.iproute.linux import IPRoute -from pyroute2.nslink.nslink import NetNS +from pyroute2 import IPRoute, NetNS def interface_exists(netns=None, *argv, **kwarg): diff --git a/tests/test_linux/test_ipdb.py b/tests/test_linux/test_ipdb.py deleted file mode 100644 index c5cdac227..000000000 --- a/tests/test_linux/test_ipdb.py +++ /dev/null @@ -1,88 +0,0 @@ -import pytest -from pr2test.marks import require_root - -from pyroute2 import IPDB - -pytestmark = [require_root()] - - -@pytest.fixture -def ictx(context): - context.ipdb = IPDB(deprecation_warning=False) - yield context - context.ipdb.release() - - -def test_interface_dummy(ictx): - ifname = ictx.new_ifname - ipaddr = ictx.new_ipaddr - interface = ictx.ipdb.create(ifname=ifname, kind='dummy') - interface.up() - interface.add_ip(f'{ipaddr}/24') - interface.commit() - - ictx.ndb.interfaces.wait(action='add', ifname=ifname, timeout=3) - ictx.ndb.addresses.wait(action='add', address=ipaddr, timeout=3) - assert ictx.ndb.interfaces[ifname]['state'] == 'up' - assert ( - ictx.ndb.addresses.wait(action='add', address=ipaddr, prefixlen=24)[ - 'index' - ] - == interface['index'] - ) - - interface.del_ip(f'{ipaddr}/24') - interface.commit() - - ictx.ndb.addresses.wait( - action='remove', address=ipaddr, prefixlen=24, timeout=3 - ) - - -def test_interface_veth(ictx): - netns = ictx.new_nsname - ictx.ndb.sources.add(netns=netns) - v0 = ictx.new_ifname - v1 = ictx.new_ifname - - veth0 = ictx.ipdb.create(ifname=v0, kind='veth', peer=v1) - veth0.up() - veth0.commit() - - veth1 = ictx.ipdb.interfaces[v1] - veth1['net_ns_fd'] = netns - veth1.commit() - - ictx.ndb.interfaces.wait(ifname=v0, target='localhost', timeout=3) - ictx.ndb.interfaces.wait(ifname=v1, target=netns, timeout=3) - - -def test_interface_bridge(ictx): - ifname = ictx.new_ifname - - with ictx.ipdb.create(ifname=ifname, kind='bridge') as i: - i.up() - i['address'] = '00:11:22:33:44:55' - i['br_stp_state'] = 1 - i['br_forward_delay'] = 1000 - - i = ictx.ndb.interfaces.wait(ifname=ifname, timeout=3) - assert i['state'] == 'up' - assert i['address'] == '00:11:22:33:44:55' - assert i['br_stp_state'] == 1 - assert i['br_forward_delay'] == 1000 - - -def test_route_basic(ictx): - ipaddr = ictx.new_ipaddr - gateway = ictx.new_ipaddr - net = ictx.new_ip4net - ifname = ictx.default_interface.ifname - - with ictx.ipdb.interfaces[ifname] as i: - i.up() - i.add_ip(f'{ipaddr}/24') - - ictx.ipdb.routes.add( - gateway=gateway, dst=f'{net.network}/{net.netmask}' - ).commit() diff --git a/tests/test_linux/test_remote.py b/tests/test_linux/test_remote.py deleted file mode 100644 index 73962abb7..000000000 --- a/tests/test_linux/test_remote.py +++ /dev/null @@ -1,14 +0,0 @@ -from pr2test.context_manager import skip_if_not_supported - -from pyroute2 import IPRoute, RemoteIPRoute - - -@skip_if_not_supported -def test_links(): - with IPRoute() as ipr: - links1 = set([x.get_attr('IFLA_IFNAME') for x in ipr.get_links()]) - - with RemoteIPRoute() as ipr: - links2 = set([x.get_attr('IFLA_IFNAME') for x in ipr.get_links()]) - - assert links1 == links2