Skip to content

Commit

Permalink
requests: filter fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
svinota committed Dec 11, 2024
1 parent 7133bf1 commit c7d4d1a
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 202 deletions.
134 changes: 64 additions & 70 deletions pyroute2/iproute/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import time
import warnings
from functools import partial
from socket import AF_INET, AF_UNSPEC
from socket import AF_INET, AF_INET6, AF_UNSPEC

from pyroute2.common import basestring
from pyroute2.common import AF_MPLS, basestring
from pyroute2.config import AF_BRIDGE
from pyroute2.netlink import NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST, NLMSG_ERROR
from pyroute2.netlink.core import SyncAPI
Expand Down Expand Up @@ -53,7 +53,13 @@
RTMGRP_DEFAULTS,
RTMGRP_IPV4_IFADDR,
RTMGRP_IPV4_ROUTE,
RTMGRP_IPV4_RULE,
RTMGRP_IPV6_IFADDR,
RTMGRP_IPV6_ROUTE,
RTMGRP_IPV6_RULE,
RTMGRP_LINK,
RTMGRP_MPLS_ROUTE,
RTMGRP_NEIGH,
ndmsg,
)
from pyroute2.netlink.rtnl.fibmsg import fibmsg
Expand Down Expand Up @@ -115,18 +121,19 @@ def get_dump_filter(mode, command, query):
if command != 'dump':
return RequestProcessor(), query
if 'match' in query:
return query.pop('match'), query
else:
new_query = {}
if 'family' in query:
new_query['family'] = query.pop('family')
dump_filter = RequestProcessor(context=query, prime=query)
for rf in query.pop(
'dump_filter', get_default_request_filters(mode, command)
):
dump_filter.add_filter(rf)
dump_filter.finalize()
return dump_filter, new_query
query = query['match']
if callable(query):
return query, {}
new_query = {}
if 'family' in query:
new_query['family'] = query.pop('family')
dump_filter = RequestProcessor(context=query, prime=query)
for rf in query.pop(
'dump_filter', get_default_request_filters(mode, command)
):
dump_filter.add_filter(rf)
dump_filter.finalize()
return dump_filter, new_query


def get_arguments_processor(mode, command, query, parameters=None):
Expand Down Expand Up @@ -189,26 +196,6 @@ class RTNL_API:
ipr.link('set', index=dev, state='up')
'''

def match_one_message(self, dump_filter, msg):
if hasattr(dump_filter, '__call__'):
return dump_filter(msg)
elif isinstance(dump_filter, dict):
matches = []
for key in dump_filter:
# get the attribute
if isinstance(key, str):
nkey = (key,)
elif isinstance(key, tuple):
nkey = key
else:
continue
value = msg.get_nested(*nkey)
if value is not None and callable(dump_filter[key]):
matches.append(dump_filter[key](value))
else:
matches.append(dump_filter[key] == value)
return all(matches)

def filter_messages(self, dump_filter, msgs):
'''
Filter messages using `dump_filter`. The filter might be a
Expand Down Expand Up @@ -288,36 +275,31 @@ async def dump(self, groups=None):
# and the code may be run on BSD systems as well, though
# BSD systems have only subset of the API
#
# if self.uname[0] == 'OpenBSD':
# groups_map = {
# 1: [
# self.get_links,
# self.get_addr,
# self.get_neighbours,
# self.get_routes,
# ]
# }
# else:
# groups_map = {
# RTMGRP_LINK: [
# self.get_links,
# self.get_vlans,
# partial(self.fdb, 'dump'),
# ],
# RTMGRP_IPV4_IFADDR: [partial(self.get_addr, family=AF_INET)],
# RTMGRP_IPV6_IFADDR: [partial(self.get_addr, family=AF_INET6)],
# RTMGRP_NEIGH: [self.get_neighbours],
# RTMGRP_IPV4_ROUTE: [partial(self.get_routes, family=AF_INET)],
# RTMGRP_IPV6_ROUTE: [partial(self.get_routes, family=AF_INET6)],
# RTMGRP_MPLS_ROUTE: [partial(self.get_routes, family=AF_MPLS)],
# RTMGRP_IPV4_RULE: [partial(self.get_rules, family=AF_INET)],
# RTMGRP_IPV6_RULE: [partial(self.get_rules, family=AF_INET6)],
# }
groups_map = {
RTMGRP_LINK: [partial(self.link, 'dump')],
RTMGRP_IPV4_IFADDR: [partial(self.addr, 'dump', family=AF_INET)],
RTMGRP_IPV4_ROUTE: [partial(self.route, 'dump', family=AF_INET)],
}
if self.uname[0] == 'OpenBSD':
groups_map = {
1: [
self.get_links,
self.get_addr,
self.get_neighbours,
self.get_routes,
]
}
else:
groups_map = {
RTMGRP_LINK: [
self.get_links,
self.get_vlans,
partial(self.fdb, 'dump'),
],
RTMGRP_IPV4_IFADDR: [partial(self.get_addr, family=AF_INET)],
RTMGRP_IPV6_IFADDR: [partial(self.get_addr, family=AF_INET6)],
RTMGRP_NEIGH: [self.get_neighbours],
RTMGRP_IPV4_ROUTE: [partial(self.get_routes, family=AF_INET)],
RTMGRP_IPV6_ROUTE: [partial(self.get_routes, family=AF_INET6)],
RTMGRP_MPLS_ROUTE: [partial(self.get_routes, family=AF_MPLS)],
RTMGRP_IPV4_RULE: [partial(self.get_rules, family=AF_INET)],
RTMGRP_IPV6_RULE: [partial(self.get_rules, family=AF_INET6)],
}
for group, methods in groups_map.items():
if group & (groups if groups is not None else self.groups):
for method in methods:
Expand Down Expand Up @@ -430,7 +412,7 @@ async def get_qdiscs(self, index=None):
'''
return await self.tc('dump')

def get_filters(self, index=0, handle=0, parent=0):
async def get_filters(self, index=0, handle=0, parent=0):
'''
Get filters for specified interface, handle and parent.
'''
Expand All @@ -439,16 +421,20 @@ def get_filters(self, index=0, handle=0, parent=0):
msg['index'] = index
msg['handle'] = transform_handle(handle)
msg['parent'] = transform_handle(parent)
return tuple(self.nlm_request(msg, RTM_GETTFILTER))
request = NetlinkRequest(self, msg, msg_type=RTM_GETTFILTER)
await request.send()
return request.response()

def get_classes(self, index=0):
async def get_classes(self, index=0):
'''
Get classes for specified interface.
'''
msg = tcmsg()
msg['family'] = AF_UNSPEC
msg['index'] = index
return tuple(self.nlm_request(msg, RTM_GETTCLASS))
request = NetlinkRequest(self, msg, msg_type=RTM_GETTCLASS)
await request.send()
return request.response()

async def get_vlans(self, **kwarg):
'''
Expand Down Expand Up @@ -820,7 +806,9 @@ async def get_default_routes(self, family=AF_UNSPEC, table=DEFAULT_TABLE):
'''
msg = rtmsg()
msg['family'] = family
dump_filter, _ = get_dump_filter('route', 'dump', {'table': table} if table is not None else {})
dump_filter, _ = get_dump_filter(
'route', 'dump', {'table': table} if table is not None else {}
)
request = NetlinkRequest(
self,
msg,
Expand Down Expand Up @@ -2461,6 +2449,8 @@ def __getattr__(self, name):
]
async_dump_methods = [
'get_qdiscs',
'get_filters',
'get_classes',
'get_links',
'get_addr',
'get_neighbours',
Expand All @@ -2479,7 +2469,11 @@ async def collect_dump():
async def collect_op():
return await symbol(*argv, **kwarg)

if len(argv) > 0 and argv[0].startswith('dump'):
if (
len(argv) > 0
and isinstance(argv[0], str)
and argv[0].startswith('dump')
):
task = collect_dump
else:
task = collect_op
Expand Down
10 changes: 6 additions & 4 deletions pyroute2/ndb/objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def dump(cls, view):
)

@classmethod
def spec_normalize(cls, processed, spec):
return processed
def spec_normalize(cls, spec):
return spec

@staticmethod
def key_load_context(key, context):
Expand Down Expand Up @@ -350,13 +350,15 @@ def __init__(
def new_spec(cls, spec, context=None, localhost=None):
if isinstance(spec, Record):
spec = spec._as_dict()
rp = RequestProcessor(context=spec, prime=spec)
spec = cls.spec_normalize(spec)
rp = RequestProcessor(context=spec)
rp.add_filter(cls.field_filter())
rp.update(spec)
if isinstance(context, dict):
rp.update(context)
if 'target' not in rp and localhost is not None:
rp['target'] = localhost
return cls.spec_normalize(rp, spec)
return rp

@staticmethod
def resolve(view, spec, fields, policy=RSLV_IGNORE):
Expand Down
6 changes: 3 additions & 3 deletions pyroute2/ndb/objects/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def compare_record(left, right):
)

@classmethod
def spec_normalize(cls, processed, spec):
def spec_normalize(cls, spec):
'''
Address key normalization::
Expand All @@ -328,8 +328,8 @@ def spec_normalize(cls, processed, spec):
"prefixlen": 24}
'''
if isinstance(spec, str):
processed['address'] = spec
return processed
return {'address': spec}
return spec

def key_repr(self):
return '%s/%s %s/%s' % (
Expand Down
8 changes: 4 additions & 4 deletions pyroute2/ndb/objects/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ def __setitem__(self, key, value):
super(Interface, self).__setitem__(key, value)

@classmethod
def spec_normalize(cls, processed, spec):
def spec_normalize(cls, spec):
'''
Interface key normalization::
Expand All @@ -889,10 +889,10 @@ def spec_normalize(cls, processed, spec):
'''
if isinstance(spec, basestring):
processed['ifname'] = spec
return {'ifname': spec}
elif isinstance(spec, int):
processed['index'] = spec
return processed
return {'index': spec}
return spec

def complete_key(self, key):
if isinstance(key, dict):
Expand Down
10 changes: 5 additions & 5 deletions pyroute2/ndb/objects/netns.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ def __init__(self, *argv, **kwarg):
super(NetNS, self).__init__(*argv, **kwarg)

@classmethod
def spec_normalize(cls, processed, spec):
def spec_normalize(cls, spec):
if isinstance(spec, basestring):
processed['path'] = spec
path = netns._get_netnspath(processed['path'])
spec = {'path': spec}
path = netns._get_netnspath(spec['path'])
# on Python3 _get_netnspath() returns bytes, not str, so
# we have to decode it here in order to avoid issues with
# cache keys and DB inserts
if hasattr(path, 'decode'):
path = path.decode('utf-8')
processed['path'] = path
return processed
spec['path'] = path
return spec

def __setitem__(self, key, value):
if self.state == 'system':
Expand Down
6 changes: 3 additions & 3 deletions pyroute2/ndb/objects/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,10 @@ def dump(cls, view):
yield record

@classmethod
def spec_normalize(cls, processed, spec):
def spec_normalize(cls, spec):
if isinstance(spec, basestring):
processed['dst'] = spec
return processed
return {'dst': spec}
return spec

@classmethod
def compare_record(self, left, right):
Expand Down
9 changes: 8 additions & 1 deletion pyroute2/netlink/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,14 @@ def __exit__(self, exc_type, exc_value, traceback):
self.close()

def __getattr__(self, key):
if key in ('pid', 'send', 'recv', 'sendto'):
if key in (
'pid',
'send',
'recv',
'sendto',
'register_policy',
'unregister_policy',
):
return getattr(self.asyncore, key)
raise AttributeError(key)

Expand Down
8 changes: 2 additions & 6 deletions pyroute2/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,9 @@ def match_one_message(self, msg):
matches = []
for key in self.dump_filter:
# get the attribute
if isinstance(key, str):
nkey = (key,)
elif isinstance(key, tuple):
nkey = key
else:
if not isinstance(key, (str, tuple)):
continue
value = msg.get(*nkey)
value = msg.get(key)
if value is not None and callable(self.dump_filter[key]):
matches.append(self.dump_filter[key](value))
else:
Expand Down
1 change: 1 addition & 0 deletions pyroute2/requests/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, context=None, prime=None):
self.reset_filters()
self.reset_mark()
self.parameters = {}
prime = {} if prime is None else prime
self.context = (
context if isinstance(context, (dict, weakref.ProxyType)) else {}
)
Expand Down
9 changes: 7 additions & 2 deletions pyroute2/requests/tc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,14 @@ def finalize(self, context):
'parent', getattr(plugin, 'parent', 0)
)
if set(context.keys()) > set(('kind', 'index', 'handle')):
get_parameters = None
if self.command[-5:] == 'class':
context['options'] = plugin.get_class_parameters(context)
get_parameters = getattr(
plugin, 'get_class_parameters', None
)
else:
context['options'] = plugin.get_parameters(context)
get_parameters = getattr(plugin, 'get_parameters', None)
if get_parameters is not None:
context['options'] = get_parameters(dict(context))
if hasattr(plugin, 'fix_request'):
plugin.fix_request(context)
3 changes: 1 addition & 2 deletions tests/test_linux/pr2test/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pyroute2.iproute.linux import IPRoute
from pyroute2.nslink.nslink import NetNS
from pyroute2 import IPRoute, NetNS


def interface_exists(netns=None, *argv, **kwarg):
Expand Down
Loading

0 comments on commit c7d4d1a

Please sign in to comment.