From 0b2d75c09a08b4551117a4bb5b1fbb05710b9856 Mon Sep 17 00:00:00 2001 From: Yang Sun <sunytt@google.com> Date: Wed, 6 Dec 2023 19:46:56 +0800 Subject: [PATCH] Add support for processing netlink dump messages Test: atest NetworkStaticLibTests:com.android.net.moduletests.util.netlink.NetlinkUtilsTest Change-Id: I15e208dca5ed6a723585f79cf09276017bc8a885 --- .../net/module/util/netlink/NetlinkUtils.java | 96 ++++++++++++- .../module/util/netlink/NetlinkUtilsTest.java | 131 +++++++----------- 2 files changed, 138 insertions(+), 89 deletions(-) diff --git a/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java b/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java index f1f30d37f3..f6282fdf5a 100644 --- a/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java +++ b/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java @@ -29,7 +29,12 @@ import static android.system.OsConstants.SOL_SOCKET; import static android.system.OsConstants.SO_RCVBUF; import static android.system.OsConstants.SO_RCVTIMEO; import static android.system.OsConstants.SO_SNDTIMEO; +import static com.android.net.module.util.netlink.NetlinkConstants.hexify; +import static com.android.net.module.util.netlink.NetlinkConstants.NLMSG_DONE; +import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP; +import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST; +import android.net.ParseException; import android.net.util.SocketUtils; import android.system.ErrnoException; import android.system.Os; @@ -47,7 +52,11 @@ import java.net.Inet6Address; import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import java.util.function.Consumer; +import java.util.stream.Collectors; /** * Utilities for netlink related class that may not be able to fit into a specific class. @@ -163,11 +172,7 @@ public class NetlinkUtils { Log.e(TAG, errPrefix, e); throw new ErrnoException(errPrefix, EIO, e); } finally { - try { - SocketUtils.closeSocket(fd); - } catch (IOException e) { - // Nothing we can do here - } + closeSocketQuietly(fd); } } @@ -308,4 +313,85 @@ public class NetlinkUtils { } private NetlinkUtils() {} + + /** + * Sends a netlink dump request and processes the returned dump messages + * + * @param <T> extends NetlinkMessage + * @param dumpRequestMessage netlink dump request message to be sent + * @param nlFamily netlink family + * @param msgClass expected class of the netlink message + * @param func function defined by caller to handle the dump messages + * @throws SocketException when fails to create socket + * @throws InterruptedIOException when fails to read the dumpFd + * @throws ErrnoException when fails to send dump request + * @throws ParseException when message can't be parsed + */ + public static <T extends NetlinkMessage> void getAndProcessNetlinkDumpMessages( + byte[] dumpRequestMessage, int nlFamily, Class<T> msgClass, + Consumer<T> func) + throws SocketException, InterruptedIOException, ErrnoException, ParseException { + // Create socket and send dump request + final FileDescriptor fd; + try { + fd = netlinkSocketForProto(nlFamily); + } catch (ErrnoException e) { + Log.e(TAG, "Failed to create netlink socket " + e); + throw e.rethrowAsSocketException(); + } + + try { + connectToKernel(fd); + } catch (ErrnoException | SocketException e) { + Log.e(TAG, "Failed to connect netlink socket to kernel " + e); + closeSocketQuietly(fd); + return; + } + + try { + sendMessage(fd, dumpRequestMessage, 0, dumpRequestMessage.length, IO_TIMEOUT_MS); + } catch (InterruptedIOException | ErrnoException e) { + Log.e(TAG, "Failed to send dump request " + e); + closeSocketQuietly(fd); + throw e; + } + + while (true) { + final ByteBuffer buf = recvMessage( + fd, NetlinkUtils.DEFAULT_RECV_BUFSIZE, IO_TIMEOUT_MS); + + while (buf.remaining() > 0) { + final int position = buf.position(); + final NetlinkMessage nlMsg = NetlinkMessage.parse(buf, nlFamily); + if (nlMsg == null) { + // Move to the position where parse started for error log. + buf.position(position); + Log.e(TAG, "Failed to parse netlink message: " + hexify(buf)); + closeSocketQuietly(fd); + throw new ParseException("Failed to parse netlink message"); + } + + if (nlMsg.getHeader().nlmsg_type == NLMSG_DONE) { + closeSocketQuietly(fd); + return; + } + + if (!msgClass.isInstance(nlMsg)) { + Log.e(TAG, "Received unexpected netlink message: " + nlMsg); + continue; + } + + final T msg = (T) nlMsg; + func.accept(msg); + } + } + } + + private static void closeSocketQuietly(final FileDescriptor fd) { + try { + SocketUtils.closeSocket(fd); + } catch (IOException e) { + // Nothing we can do here + } + } } diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java index 5a231fc523..17d4e81671 100644 --- a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java +++ b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java @@ -55,6 +55,9 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; @RunWith(AndroidJUnit4.class) @SmallTest @@ -65,19 +68,14 @@ public class NetlinkUtilsTest { @Test public void testGetNeighborsQuery() throws Exception { - final FileDescriptor fd = NetlinkUtils.netlinkSocketForProto(NETLINK_ROUTE); - assertNotNull(fd); - - NetlinkUtils.connectToKernel(fd); - - final NetlinkSocketAddress localAddr = (NetlinkSocketAddress) Os.getsockname(fd); - assertNotNull(localAddr); - assertEquals(0, localAddr.getGroupsMask()); - assertTrue(0 != localAddr.getPortId()); - final byte[] req = RtNetlinkNeighborMessage.newGetNeighborsRequest(TEST_SEQNO); assertNotNull(req); + List<RtNetlinkNeighborMessage> msgs = new ArrayList<>(); + Consumer<RtNetlinkNeighborMessage> handleNlDumpMsg = (msg) -> { + msgs.add(msg); + }; + final Context ctx = InstrumentationRegistry.getInstrumentation().getContext(); final int targetSdk = ctx.getPackageManager() @@ -94,7 +92,8 @@ public class NetlinkUtilsTest { assumeFalse("network_stack context is expected to have permission to send RTM_GETNEIGH", ctxt.startsWith("u:r:network_stack:s0")); try { - NetlinkUtils.sendMessage(fd, req, 0, req.length, TEST_TIMEOUT_MS); + NetlinkUtils.<RtNetlinkNeighborMessage>getAndProcessNetlinkDumpMessages(req, + NETLINK_ROUTE, RtNetlinkNeighborMessage.class, handleNlDumpMsg); fail("RTM_GETNEIGH is not allowed for apps targeting SDK > 31 on T+ platforms," + " target SDK version: " + targetSdk); } catch (ErrnoException e) { @@ -105,106 +104,70 @@ public class NetlinkUtilsTest { } // Check that apps targeting lower API levels / running on older platforms succeed - assertEquals(req.length, - NetlinkUtils.sendMessage(fd, req, 0, req.length, TEST_TIMEOUT_MS)); - - int neighMessageCount = 0; - int doneMessageCount = 0; - - while (doneMessageCount == 0) { - ByteBuffer response = - NetlinkUtils.recvMessage(fd, DEFAULT_RECV_BUFSIZE, TEST_TIMEOUT_MS); - assertNotNull(response); - assertTrue(StructNlMsgHdr.STRUCT_SIZE <= response.limit()); - assertEquals(0, response.position()); - assertEquals(ByteOrder.nativeOrder(), response.order()); - - // Verify the messages at least appears minimally reasonable. - while (response.remaining() > 0) { - final NetlinkMessage msg = NetlinkMessage.parse(response, NETLINK_ROUTE); - assertNotNull(msg); - final StructNlMsgHdr hdr = msg.getHeader(); - assertNotNull(hdr); - - if (hdr.nlmsg_type == NetlinkConstants.NLMSG_DONE) { - doneMessageCount++; - continue; - } - - assertEquals(NetlinkConstants.RTM_NEWNEIGH, hdr.nlmsg_type); - assertTrue(msg instanceof RtNetlinkNeighborMessage); - assertTrue((hdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0); - assertEquals(TEST_SEQNO, hdr.nlmsg_seq); - assertEquals(localAddr.getPortId(), hdr.nlmsg_pid); - - neighMessageCount++; - } + NetlinkUtils.<RtNetlinkNeighborMessage>getAndProcessNetlinkDumpMessages(req, + NETLINK_ROUTE, RtNetlinkNeighborMessage.class, handleNlDumpMsg); + + for (var msg : msgs) { + assertNotNull(msg); + final StructNlMsgHdr hdr = msg.getHeader(); + assertNotNull(hdr); + assertEquals(NetlinkConstants.RTM_NEWNEIGH, hdr.nlmsg_type); + assertTrue((hdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0); + assertEquals(TEST_SEQNO, hdr.nlmsg_seq); } - assertEquals(1, doneMessageCount); // TODO: make sure this test passes sanely in airplane mode. - assertTrue(neighMessageCount > 0); - - IoUtils.closeQuietly(fd); + assertTrue(msgs.size() > 0); } @Test public void testBasicWorkingGetAddrQuery() throws Exception { - final FileDescriptor fd = NetlinkUtils.netlinkSocketForProto(NETLINK_ROUTE); - assertNotNull(fd); - - NetlinkUtils.connectToKernel(fd); - - final NetlinkSocketAddress localAddr = (NetlinkSocketAddress) Os.getsockname(fd); - assertNotNull(localAddr); - assertEquals(0, localAddr.getGroupsMask()); - assertTrue(0 != localAddr.getPortId()); - final int testSeqno = 8; final byte[] req = newGetAddrRequest(testSeqno); assertNotNull(req); - final long timeout = 500; - assertEquals(req.length, NetlinkUtils.sendMessage(fd, req, 0, req.length, timeout)); - - int addrMessageCount = 0; + List<RtNetlinkAddressMessage> msgs = new ArrayList<>(); + Consumer<RtNetlinkAddressMessage> handleNlDumpMsg = (msg) -> { + msgs.add(msg); + }; + NetlinkUtils.<RtNetlinkAddressMessage>getAndProcessNetlinkDumpMessages(req, NETLINK_ROUTE, + RtNetlinkAddressMessage.class, handleNlDumpMsg); - while (true) { - ByteBuffer response = NetlinkUtils.recvMessage(fd, DEFAULT_RECV_BUFSIZE, timeout); - assertNotNull(response); - assertTrue(StructNlMsgHdr.STRUCT_SIZE <= response.limit()); - assertEquals(0, response.position()); - assertEquals(ByteOrder.nativeOrder(), response.order()); + boolean ipv4LoopbackAddressFound = false; + boolean ipv6LoopbackAddressFound = false; + final InetAddress loopbackIpv4 = InetAddress.getByName("127.0.0.1"); + final InetAddress loopbackIpv6 = InetAddress.getByName("::1"); - final NetlinkMessage msg = NetlinkMessage.parse(response, NETLINK_ROUTE); + for (var msg : msgs) { assertNotNull(msg); final StructNlMsgHdr nlmsghdr = msg.getHeader(); assertNotNull(nlmsghdr); - - if (nlmsghdr.nlmsg_type == NetlinkConstants.NLMSG_DONE) { - break; - } - assertEquals(NetlinkConstants.RTM_NEWADDR, nlmsghdr.nlmsg_type); assertTrue((nlmsghdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0); assertEquals(testSeqno, nlmsghdr.nlmsg_seq); - assertEquals(localAddr.getPortId(), nlmsghdr.nlmsg_pid); assertTrue(msg instanceof RtNetlinkAddressMessage); - addrMessageCount++; - - // From the query response we can see the RTM_NEWADDR messages representing for IPv4 - // and IPv6 loopback address: 127.0.0.1 and ::1. + // When parsing the full response we can see the RTM_NEWADDR messages representing for + // IPv4 and IPv6 loopback address: 127.0.0.1 and ::1 and non-loopback addresses. final StructIfaddrMsg ifaMsg = ((RtNetlinkAddressMessage) msg).getIfaddrHeader(); final InetAddress ipAddress = ((RtNetlinkAddressMessage) msg).getIpAddress(); assertTrue( "Non-IP address family: " + ifaMsg.family, ifaMsg.family == AF_INET || ifaMsg.family == AF_INET6); - assertTrue(ipAddress.isLoopbackAddress()); - } + assertNotNull(ipAddress); - assertTrue(addrMessageCount > 0); + if (ipAddress.equals(loopbackIpv4)) { + ipv4LoopbackAddressFound = true; + assertTrue(ipAddress.isLoopbackAddress()); + } + if (ipAddress.equals(loopbackIpv6)) { + ipv6LoopbackAddressFound = true; + assertTrue(ipAddress.isLoopbackAddress()); + } + } - IoUtils.closeQuietly(fd); + assertTrue(msgs.size() > 0); + // Check ipv4 and ipv6 loopback addresses are in the output + assertTrue(ipv4LoopbackAddressFound && ipv6LoopbackAddressFound); } /** A convenience method to create an RTM_GETADDR request message. */ -- GitLab