Commit 2557e2ec authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'netlink-add-nftables-spec-w-multi-messages'

Donald Hunter says:

====================
netlink: Add nftables spec w/ multi messages

This series adds a ynl spec for nftables and extends ynl with a --multi
command line option that makes it possible to send transactional batches
for nftables.

This series includes a patch for nfnetlink which adds ACK processing for
batch begin/end messages. If you'd prefer that to be sent separately to
nf-next then I can do so, but I included it here so that it gets seen in
context.

An example of usage is:

./tools/net/ynl/cli.py \
 --spec Documentation/netlink/specs/nftables.yaml \
 --multi batch-begin '{"res-id": 10}' \
 --multi newtable '{"name": "test", "nfgen-family": 1}' \
 --multi newchain '{"name": "chain", "table": "test", "nfgen-family": 1}' \
 --multi batch-end '{"res-id": 10}'
[None, None, None, None]

It can also be used for bundling get requests:

./tools/net/ynl/cli.py \
 --spec Documentation/netlink/specs/nftables.yaml \
 --multi gettable '{"name": "test", "nfgen-family": 1}' \
 --multi getchain '{"name": "chain", "table": "test", "nfgen-family": 1}' \
 --output-json
[{"name": "test", "use": 1, "handle": 1, "flags": [],
 "nfgen-family": 1, "version": 0, "res-id": 2},
 {"table": "test", "name": "chain", "handle": 1, "use": 0,
 "nfgen-family": 1, "version": 0, "res-id": 2}]

There are 2 issues that may be worth resolving:

 - ynl reports errors by raising an NlError exception so only the first
   error gets reported. This could be changed to add errors to the list
   of responses so that multiple errors could be reported.

 - If any message does not get a response (e.g. batch-begin w/o patch 2)
   then ynl waits indefinitely. A recv timeout could be added which
   would allow ynl to terminate.
====================

Link: https://lore.kernel.org/r/20240418104737.77914-1-donald.hunter@gmail.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents af046fd1 bf2ac490
Loading
Loading
Loading
Loading
+1264 −0

File added.

Preview size limit exceeded, changes collapsed.

+5 −0
Original line number Diff line number Diff line
@@ -427,6 +427,9 @@ static void nfnetlink_rcv_batch(struct sk_buff *skb, struct nlmsghdr *nlh,

	nfnl_unlock(subsys_id);

	if (nlh->nlmsg_flags & NLM_F_ACK)
		nfnl_err_add(&err_list, nlh, 0, &extack);

	while (skb->len >= nlmsg_total_size(0)) {
		int msglen, type;

@@ -573,6 +576,8 @@ static void nfnetlink_rcv_batch(struct sk_buff *skb, struct nlmsghdr *nlh,
		} else if (err) {
			ss->abort(net, oskb, NFNL_ABORT_NONE);
			netlink_ack(oskb, nlmsg_hdr(oskb), err, NULL);
		} else if (nlh->nlmsg_flags & NLM_F_ACK) {
			nfnl_err_add(&err_list, nlh, 0, &extack);
		}
	} else {
		enum nfnl_abort_action abort_action;
+22 −3
Original line number Diff line number Diff line
@@ -19,13 +19,28 @@ class YnlEncoder(json.JSONEncoder):


def main():
    parser = argparse.ArgumentParser(description='YNL CLI sample')
    description = """
    YNL CLI utility - a general purpose netlink utility that uses YAML
    specs to drive protocol encoding and decoding.
    """
    epilog = """
    The --multi option can be repeated to include several do operations
    in the same netlink payload.
    """

    parser = argparse.ArgumentParser(description=description,
                                     epilog=epilog)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--schema', dest='schema', type=str)
    parser.add_argument('--no-schema', action='store_true')
    parser.add_argument('--json', dest='json_text', type=str)
    parser.add_argument('--do', dest='do', type=str)
    parser.add_argument('--dump', dest='dump', type=str)

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--do', dest='do', metavar='DO-OPERATION', type=str)
    group.add_argument('--multi', dest='multi', nargs=2, action='append',
                       metavar=('DO-OPERATION', 'JSON_TEXT'), type=str)
    group.add_argument('--dump', dest='dump', metavar='DUMP-OPERATION', type=str)

    parser.add_argument('--sleep', dest='sleep', type=int)
    parser.add_argument('--subscribe', dest='ntf', type=str)
    parser.add_argument('--replace', dest='flags', action='append_const',
@@ -73,6 +88,10 @@ def main():
        if args.dump:
            reply = ynl.dump(args.dump, attrs)
            output(reply)
        if args.multi:
            ops = [ (item[0], json.loads(item[1]), args.flags or []) for item in args.multi ]
            reply = ynl.do_multi(ops)
            output(reply)
    except NlError as e:
        print(e)
        exit(1)
+55 −27
Original line number Diff line number Diff line
@@ -386,11 +386,8 @@ class NetlinkProtocol:
    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg):
    def decode(self, ynl, nl_msg, op):
        msg = self._decode(nl_msg)
        fixed_header_size = 0
        if ynl:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg
@@ -797,7 +794,7 @@ class YnlFamily(SpecFamily):
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
@@ -922,7 +919,8 @@ class YnlFamily(SpecFamily):
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg)
                op = self.rsp_by_value[nl_msg.cmd()]
                decoded = self.nlproto.decode(self, nl_msg, op)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue
@@ -940,16 +938,11 @@ class YnlFamily(SpecFamily):

      return op['do']['request']['attributes'].copy()

    def _op(self, method, vals, flags=None, dump=False):
        op = self.ops[method]

    def _encode_message(self, op, vals, flags, req_seq):
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag
        if dump:
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            msg += self._encode_struct(op.fixed_header, vals)
@@ -957,18 +950,36 @@ class YnlFamily(SpecFamily):
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

        self.sock.send(msg, 0)
    def _ops(self, ops):
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags)
            payload += msg
            req_seq += 1

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=op.attr_set)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                    self._decode_extack(msg, op, nl_msg.extack)
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = self.rsp_by_value[nl_msg.cmd()]
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
@@ -976,13 +987,25 @@ class YnlFamily(SpecFamily):
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)
                    done = True

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg)
                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
@@ -993,18 +1016,23 @@ class YnlFamily(SpecFamily):
                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                rsp.append(rsp_msg)
                op_rsp.append(rsp_msg)

        if dump:
            return rsp
        if not rsp:
            return None
        if len(rsp) == 1:
            return rsp[0]
        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        req_flags = flags or []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, [], dump=True)
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)