Commit 3287e812 authored by Ilya Maximets's avatar Ilya Maximets Committed by Jakub Kicinski
Browse files

tools: ynl: support listening on all nsids



A new method ntf_listen_all_nsid() to enable listening on events from
all namespaces.  Useful for testing cross-namespace functionality.

recv() replaced with recvmsg() to be able to receive NSID through the
ancillary data.

Signed-off-by: default avatarIlya Maximets <i.maximets@ovn.org>
Link: https://patch.msgid.link/20260520172317.175168-4-i.maximets@ovn.org


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent 4db79a32
Loading
Loading
Loading
Loading
+32 −5
Original line number Diff line number Diff line
@@ -42,6 +42,7 @@ class Netlink:
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_LISTEN_ALL_NSID = 8
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12
@@ -680,6 +681,7 @@ class YnlFamily(SpecFamily):
    Notification API:

      ynl.ntf_subscribe(mcast_name)      -- join a multicast group
      ynl.ntf_listen_all_nsid()          -- listen on all netns
      ynl.check_ntf()                    -- drain pending notifications
      ynl.poll_ntf(duration=None)        -- yield notifications

@@ -748,6 +750,23 @@ class YnlFamily(SpecFamily):
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def ntf_listen_all_nsid(self):
        """Enable NETLINK_LISTEN_ALL_NSID to receive notifications from all
        namespaces that have an nsid mapped in the current one."""
        self.sock.setsockopt(Netlink.SOL_NETLINK,
                             Netlink.NETLINK_LISTEN_ALL_NSID, 1)

    @staticmethod
    def _decode_nsid(ancdata):
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if (cmsg_level == Netlink.SOL_NETLINK and
                    cmsg_type == Netlink.NETLINK_LISTEN_ALL_NSID):
                nsid = struct.unpack('i', cmsg_data)[0]
                if nsid >= 0:
                    return nsid
                return None
        return None

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

@@ -1235,7 +1254,7 @@ class YnlFamily(SpecFamily):
                            f" when parsing '{attr_spec['name']}'")
        return raw

    def handle_ntf(self, decoded):
    def handle_ntf(self, decoded, nsid=None):
        msg = {}
        if self.include_raw:
            msg['raw'] = decoded
@@ -1246,15 +1265,22 @@ class YnlFamily(SpecFamily):

        msg['name'] = op['name']
        msg['msg'] = attrs
        if nsid is not None:
            msg['nsid'] = nsid
        self.async_msg_queue.put(msg)

    def _recvmsg(self, flags=0):
        reply, ancdata, _, _ = self.sock.recvmsg(self._recv_size, 4096, flags)
        return reply, ancdata

    def check_ntf(self):
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
                reply, ancdata = self._recvmsg(socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nsid = self._decode_nsid(ancdata)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
@@ -1271,7 +1297,7 @@ class YnlFamily(SpecFamily):
                    print("Unexpected msg id while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)
                self.handle_ntf(decoded, nsid)

    def poll_ntf(self, duration=None):
        start_time = time.time()
@@ -1335,7 +1361,8 @@ class YnlFamily(SpecFamily):
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            reply, ancdata = self._recvmsg()
            nsid = self._decode_nsid(ancdata)
            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
@@ -1374,7 +1401,7 @@ class YnlFamily(SpecFamily):
                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        self.handle_ntf(decoded, nsid)
                        continue
                    print('Unexpected message: ' + repr(decoded))
                    continue