diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java b/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java index ea3af5e83b8b4ff401a718ca6d26139be29c1aab..651b643a3cf186fd2fae899b1e2da0c59c3219c7 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java @@ -25,6 +25,7 @@ import android.os.Handler; import android.os.Looper; import android.os.Message; +import com.android.internal.annotations.VisibleForTesting; import com.android.net.module.util.SharedLog; import com.android.server.connectivity.mdns.util.MdnsUtils; @@ -57,15 +58,46 @@ public class MdnsReplySender { @NonNull private final SharedLog mSharedLog; private final boolean mEnableDebugLog; + @NonNull + private final Dependencies mDependencies; + + /** + * Dependencies of MdnsReplySender, for injection in tests. + */ + @VisibleForTesting + public static class Dependencies { + /** + * @see Handler#sendMessageDelayed(Message, long) + */ + public void sendMessageDelayed(@NonNull Handler handler, @NonNull Message message, + long delayMillis) { + handler.sendMessageDelayed(message, delayMillis); + } + + /** + * @see Handler#removeMessages(int) + */ + public void removeMessages(@NonNull Handler handler, int what) { + handler.removeMessages(what); + } + } public MdnsReplySender(@NonNull Looper looper, @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog, boolean enableDebugLog) { + this(looper, socket, packetCreationBuffer, sharedLog, enableDebugLog, new Dependencies()); + } + + @VisibleForTesting + public MdnsReplySender(@NonNull Looper looper, @NonNull MdnsInterfaceSocket socket, + @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog, + boolean enableDebugLog, @NonNull Dependencies dependencies) { mHandler = new SendHandler(looper); mSocket = socket; mPacketCreationBuffer = packetCreationBuffer; mSharedLog = sharedLog; mEnableDebugLog = enableDebugLog; + mDependencies = dependencies; } /** @@ -74,7 +106,8 @@ public class MdnsReplySender { public void queueReply(@NonNull MdnsReplyInfo reply) { ensureRunningOnHandlerThread(mHandler); // TODO: implement response aggregation (RFC 6762 6.4) - mHandler.sendMessageDelayed(mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs); + mDependencies.sendMessageDelayed( + mHandler, mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs); if (mEnableDebugLog) { mSharedLog.v("Scheduling " + reply); @@ -104,7 +137,7 @@ public class MdnsReplySender { */ public void cancelAll() { ensureRunningOnHandlerThread(mHandler); - mHandler.removeMessages(MSG_SEND); + mDependencies.removeMessages(mHandler, MSG_SEND); } private class SendHandler extends Handler { diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsReplySenderTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsReplySenderTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..9e2933f5b4ba77d8c234c9a4c9d0cecc975ab734 --- /dev/null +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsReplySenderTest.kt @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.connectivity.mdns + +import android.net.InetAddresses +import android.net.LinkAddress +import android.os.Build +import android.os.Handler +import android.os.HandlerThread +import android.os.Message +import com.android.net.module.util.SharedLog +import com.android.server.connectivity.mdns.MdnsConstants.IPV4_SOCKET_ADDR +import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo +import com.android.testutils.DevSdkIgnoreRunner +import java.net.InetSocketAddress +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.argThat +import org.mockito.Mockito.doReturn +import org.mockito.Mockito.eq +import org.mockito.Mockito.mock +import org.mockito.Mockito.timeout +import org.mockito.Mockito.verify + +private const val TEST_PORT = 12345 +private const val DEFAULT_TIMEOUT_MS = 2000L +private const val LONG_TTL = 4_500_000L +private const val SHORT_TTL = 120_000L + +@RunWith(DevSdkIgnoreRunner::class) +@IgnoreUpTo(Build.VERSION_CODES.S_V2) +class MdnsReplySenderTest { + private val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") + private val serviceType = arrayOf("_testservice", "_tcp", "local") + private val hostname = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local") + private val hostAddresses = listOf( + LinkAddress(InetAddresses.parseNumericAddress("192.0.2.111"), 24), + LinkAddress(InetAddresses.parseNumericAddress("2001:db8::111"), 64), + LinkAddress(InetAddresses.parseNumericAddress("2001:db8::222"), 64)) + private val answers = listOf( + MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */, false /* cacheFlush */, + LONG_TTL, serviceName)) + private val additionalAnswers = listOf( + MdnsTextRecord(serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */, LONG_TTL, + listOf() /* entries */), + MdnsServiceRecord(serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */, + SHORT_TTL, 0 /* servicePriority */, 0 /* serviceWeight */, TEST_PORT, hostname), + MdnsInetAddressRecord(hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */, + SHORT_TTL, hostAddresses[0].address), + MdnsInetAddressRecord(hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */, + SHORT_TTL, hostAddresses[1].address), + MdnsInetAddressRecord(hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */, + SHORT_TTL, hostAddresses[2].address), + MdnsNsecRecord(serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */, LONG_TTL, + serviceName /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)), + MdnsNsecRecord(hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */, SHORT_TTL, + hostname /* nextDomain */, intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA))) + private val thread = HandlerThread(MdnsReplySenderTest::class.simpleName) + private val socket = mock(MdnsInterfaceSocket::class.java) + private val buffer = ByteArray(1500) + private val sharedLog = SharedLog(MdnsReplySenderTest::class.simpleName) + private val deps = mock(MdnsReplySender.Dependencies::class.java) + private val handler by lazy { Handler(thread.looper) } + private val replySender by lazy { + MdnsReplySender(thread.looper, socket, buffer, sharedLog, false /* enableDebugLog */, deps) + } + + @Before + fun setUp() { + thread.start() + doReturn(true).`when`(socket).hasJoinedIpv4() + doReturn(true).`when`(socket).hasJoinedIpv6() + } + + @After + fun tearDown() { + thread.quitSafely() + thread.join() + } + + private fun <T> runningOnHandlerAndReturn(functor: (() -> T)): T { + val future = CompletableFuture<T>() + handler.post { + future.complete(functor()) + } + return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + + private fun sendNow(packet: MdnsPacket, destination: InetSocketAddress): + Unit = runningOnHandlerAndReturn { replySender.sendNow(packet, destination) } + + private fun queueReply(reply: MdnsReplyInfo): + Unit = runningOnHandlerAndReturn { replySender.queueReply(reply) } + + @Test + fun testSendNow() { + val packet = MdnsPacket(0x8400, + listOf() /* questions */, + answers, + listOf() /* authorityRecords */, + additionalAnswers) + sendNow(packet, IPV4_SOCKET_ADDR) + verify(socket).send(argThat{ it.socketAddress.equals(IPV4_SOCKET_ADDR) }) + } + + @Test + fun testQueueReply() { + val reply = MdnsReplyInfo(answers, additionalAnswers, 20L /* sendDelayMs */, + IPV4_SOCKET_ADDR) + val handlerCaptor = ArgumentCaptor.forClass(Handler::class.java) + val messageCaptor = ArgumentCaptor.forClass(Message::class.java) + queueReply(reply) + verify(deps).sendMessageDelayed(handlerCaptor.capture(), messageCaptor.capture(), eq(20L)) + + val realHandler = handlerCaptor.value + val delayMessage = messageCaptor.value + realHandler.sendMessage(delayMessage) + verify(socket, timeout(DEFAULT_TIMEOUT_MS)).send(argThat{ + it.socketAddress.equals(IPV4_SOCKET_ADDR) + }) + } +}