Skip to content

Commit

Permalink
tools: ynl: Use dict of predefined Structs to decode scalar types
Browse files Browse the repository at this point in the history
Use a dict of predefined Struct() objects to decode scalar types in native,
big or little endian format. This removes the repetitive code for the
scalar variants and ensures all the signed variants are supported.

Signed-off-by: Donald Hunter <donald.hunter@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
Donald Hunter authored and David S. Miller committed May 24, 2023
1 parent 59088b5 commit 7c2435e
Showing 1 changed file with 44 additions and 57 deletions.
101 changes: 44 additions & 57 deletions tools/net/ynl/lib/ynl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
import functools
import os
import random
import socket
import struct
from struct import Struct
import yaml

from .nlspec import SpecFamily
Expand Down Expand Up @@ -76,10 +78,17 @@ def __str__(self):


class NlAttr:
type_formats = { 'u8' : ('B', 1), 's8' : ('b', 1),
'u16': ('H', 2), 's16': ('h', 2),
'u32': ('I', 4), 's32': ('i', 4),
'u64': ('Q', 8), 's64': ('q', 8) }
ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
type_formats = {
'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
}

def __init__(self, raw, offset):
self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
Expand All @@ -88,25 +97,17 @@ def __init__(self, raw, offset):
self.full_len = (self.payload_len + 3) & ~3
self.raw = raw[offset + 4:offset + self.payload_len]

def format_byte_order(byte_order):
@classmethod
def get_format(cls, attr_type, byte_order=None):
format = cls.type_formats[attr_type]
if byte_order:
return ">" if byte_order == "big-endian" else "<"
return ""
return format.big if byte_order == "big-endian" \
else format.little
return format.native

def as_u8(self):
return struct.unpack("B", self.raw)[0]

def as_u16(self, byte_order=None):
endian = NlAttr.format_byte_order(byte_order)
return struct.unpack(f"{endian}H", self.raw)[0]

def as_u32(self, byte_order=None):
endian = NlAttr.format_byte_order(byte_order)
return struct.unpack(f"{endian}I", self.raw)[0]

def as_u64(self, byte_order=None):
endian = NlAttr.format_byte_order(byte_order)
return struct.unpack(f"{endian}Q", self.raw)[0]
def as_scalar(self, attr_type, byte_order=None):
format = self.get_format(attr_type, byte_order)
return format.unpack(self.raw)[0]

def as_strz(self):
return self.raw.decode('ascii')[:-1]
Expand All @@ -115,17 +116,17 @@ def as_bin(self):
return self.raw

def as_c_array(self, type):
format, _ = self.type_formats[type]
return list({ x[0] for x in struct.iter_unpack(format, self.raw) })
format = self.get_format(type)
return [ x[0] for x in format.iter_unpack(self.raw) ]

def as_struct(self, members):
value = dict()
offset = 0
for m in members:
# TODO: handle non-scalar members
format, size = self.type_formats[m.type]
decoded = struct.unpack_from(format, self.raw, offset)
offset += size
format = self.get_format(m.type)
decoded = format.unpack_from(self.raw, offset)
offset += format.size
value[m.name] = decoded[0]
return value

Expand Down Expand Up @@ -184,11 +185,11 @@ def __init__(self, msg, offset, attr_space=None):
if extack.type == Netlink.NLMSGERR_ATTR_MSG:
self.extack['msg'] = extack.as_strz()
elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
self.extack['miss-type'] = extack.as_u32()
self.extack['miss-type'] = extack.as_scalar('u32')
elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
self.extack['miss-nest'] = extack.as_u32()
self.extack['miss-nest'] = extack.as_scalar('u32')
elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
self.extack['bad-attr-offs'] = extack.as_u32()
self.extack['bad-attr-offs'] = extack.as_scalar('u32')
else:
if 'unknown' not in self.extack:
self.extack['unknown'] = []
Expand Down Expand Up @@ -272,11 +273,11 @@ def _genl_load_families():
fam = dict()
for attr in gm.raw_attrs:
if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
fam['id'] = attr.as_u16()
fam['id'] = attr.as_scalar('u16')
elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
fam['name'] = attr.as_strz()
elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
fam['maxattr'] = attr.as_u32()
fam['maxattr'] = attr.as_scalar('u32')
elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
fam['mcast'] = dict()
for entry in NlAttrs(attr.raw):
Expand All @@ -286,7 +287,7 @@ def _genl_load_families():
if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
mcast_name = entry_attr.as_strz()
elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
mcast_id = entry_attr.as_u32()
mcast_id = entry_attr.as_scalar('u32')
if mcast_name and mcast_id is not None:
fam['mcast'][mcast_name] = mcast_id
if 'name' in fam and 'id' in fam:
Expand All @@ -304,9 +305,9 @@ def __init__(self, nl_msg, fixed_header_members=[]):

self.fixed_header_attrs = dict()
for m in fixed_header_members:
format, size = NlAttr.type_formats[m.type]
decoded = struct.unpack_from(format, nl_msg.raw, offset)
offset += size
format = NlAttr.get_format(m.type)
decoded = format.unpack_from(nl_msg.raw, offset)
offset += format.size
self.fixed_header_attrs[m.name] = decoded[0]

self.raw = nl_msg.raw[offset:]
Expand Down Expand Up @@ -381,21 +382,13 @@ def _add_attr(self, space, name, value):
attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
elif attr["type"] == 'flag':
attr_payload = b''
elif attr["type"] == 'u8':
attr_payload = struct.pack("B", int(value))
elif attr["type"] == 'u16':
endian = NlAttr.format_byte_order(attr.byte_order)
attr_payload = struct.pack(f"{endian}H", int(value))
elif attr["type"] == 'u32':
endian = NlAttr.format_byte_order(attr.byte_order)
attr_payload = struct.pack(f"{endian}I", int(value))
elif attr["type"] == 'u64':
endian = NlAttr.format_byte_order(attr.byte_order)
attr_payload = struct.pack(f"{endian}Q", int(value))
elif attr["type"] == 'string':
attr_payload = str(value).encode('ascii') + b'\x00'
elif attr["type"] == 'binary':
attr_payload = value
elif attr['type'] in NlAttr.type_formats:
format = NlAttr.get_format(attr['type'], attr.byte_order)
attr_payload = format.pack(int(value))
else:
raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

Expand Down Expand Up @@ -434,22 +427,16 @@ def _decode(self, attrs, space):
if attr_spec["type"] == 'nest':
subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
decoded = subdict
elif attr_spec['type'] == 'u8':
decoded = attr.as_u8()
elif attr_spec['type'] == 'u16':
decoded = attr.as_u16(attr_spec.byte_order)
elif attr_spec['type'] == 'u32':
decoded = attr.as_u32(attr_spec.byte_order)
elif attr_spec['type'] == 'u64':
decoded = attr.as_u64(attr_spec.byte_order)
elif attr_spec["type"] == 'string':
decoded = attr.as_strz()
elif attr_spec["type"] == 'binary':
decoded = self._decode_binary(attr, attr_spec)
elif attr_spec["type"] == 'flag':
decoded = True
elif attr_spec["type"] in NlAttr.type_formats:
decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
else:
raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')
raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

if not attr_spec.is_multi:
rsp[attr_spec['name']] = decoded
Expand Down Expand Up @@ -555,8 +542,8 @@ def _op(self, method, vals, dump=False):
fixed_header_members = self.consts[op.fixed_header].members
for m in fixed_header_members:
value = vals.pop(m.name)
format, _ = NlAttr.type_formats[m.type]
msg += struct.pack(format, value)
format = NlAttr.get_format(m.type)
msg += format.pack(value)
for name, value in vals.items():
msg += self._add_attr(op.attr_set.name, name, value)
msg = _genl_msg_finalize(msg)
Expand Down

0 comments on commit 7c2435e

Please sign in to comment.