diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java index 93ccb8529553560a8c28a20b292c76d2801ccc05..b04126038ed93282f9a53c28556956d06e7622c3 100644 --- a/service-t/src/com/android/server/NsdService.java +++ b/service-t/src/com/android/server/NsdService.java @@ -1709,9 +1709,14 @@ public class NsdService extends INsdManager.Stub { mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager")); handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager)); - MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder().setIsMdnsOffloadFeatureEnabled( - mDeps.isTetheringFeatureNotChickenedOut(mContext, - MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD)).build(); + MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder() + .setIsMdnsOffloadFeatureEnabled( + mDeps.isTetheringFeatureNotChickenedOut( + mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD)) + .setIncludeInetAddressRecordsInProbing( + mDeps.isFeatureEnabled( + mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING)) + .build(); mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider, new AdvertiserCallback(), LOGGER.forSubComponent("MdnsAdvertiser"), flags); mClock = deps.makeClock(); diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java index a946bca40b9bb8393f65aa8e0543dd8015666383..28e392494234c1258312ec196d57a520df8dc111 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java @@ -96,10 +96,11 @@ public class MdnsAdvertiser { @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull MdnsInterfaceAdvertiser.Callback cb, @NonNull String[] deviceHostName, - @NonNull SharedLog sharedLog) { + @NonNull SharedLog sharedLog, + @NonNull MdnsFeatureFlags mdnsFeatureFlags) { // Note NetworkInterface is final and not mockable return new MdnsInterfaceAdvertiser(socket, initialAddresses, looper, - packetCreationBuffer, cb, deviceHostName, sharedLog); + packetCreationBuffer, cb, deviceHostName, sharedLog, mdnsFeatureFlags); } /** @@ -394,7 +395,8 @@ public class MdnsAdvertiser { if (advertiser == null) { advertiser = mDeps.makeAdvertiser(socket, addresses, mLooper, mPacketCreationBuffer, mInterfaceAdvertiserCb, mDeviceHostName, - mSharedLog.forSubComponent(socket.getInterface().getName())); + mSharedLog.forSubComponent(socket.getInterface().getName()), + mMdnsFeatureFlags); mAllAdvertisers.put(socket, advertiser); advertiser.start(); } diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java index 9840409194c1e11c0bb7d0fb086713bc285ab08a..a6f78716da3484d00822ad533c5ff32f337b6a0b 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java @@ -24,14 +24,26 @@ public class MdnsFeatureFlags { */ public static final String NSD_FORCE_DISABLE_MDNS_OFFLOAD = "nsd_force_disable_mdns_offload"; + /** + * The feature flag for controlling whether the probing question should include + * InetAddressRecords or not. + */ + public static final String INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING = + "include_inet_address_records_in_probing"; + // Flag for offload feature public final boolean mIsMdnsOffloadFeatureEnabled; + // Flag for including InetAddressRecords in probing questions. + public final boolean mIncludeInetAddressRecordsInProbing; + /** * The constructor for {@link MdnsFeatureFlags}. */ - public MdnsFeatureFlags(boolean isOffloadFeatureEnabled) { + public MdnsFeatureFlags(boolean isOffloadFeatureEnabled, + boolean includeInetAddressRecordsInProbing) { mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled; + mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing; } @@ -44,12 +56,14 @@ public class MdnsFeatureFlags { public static final class Builder { private boolean mIsMdnsOffloadFeatureEnabled; + private boolean mIncludeInetAddressRecordsInProbing; /** * The constructor for {@link Builder}. */ public Builder() { mIsMdnsOffloadFeatureEnabled = false; + mIncludeInetAddressRecordsInProbing = false; } /** @@ -60,11 +74,21 @@ public class MdnsFeatureFlags { return this; } + /** + * Set if the probing question should include InetAddressRecords. + */ + public Builder setIncludeInetAddressRecordsInProbing( + boolean includeInetAddressRecordsInProbing) { + mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing; + return this; + } + /** * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder. */ public MdnsFeatureFlags build() { - return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled); + return new MdnsFeatureFlags( + mIsMdnsOffloadFeatureEnabled, mIncludeInetAddressRecordsInProbing); } } diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInetAddressRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInetAddressRecord.java index dd8a526c2a119d5133bb1c147c2eb4853b1b5b23..973fd9602e7874c45dc80db0b2ce8ad730beb70a 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsInetAddressRecord.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInetAddressRecord.java @@ -18,7 +18,7 @@ package com.android.server.connectivity.mdns; import android.annotation.Nullable; -import com.android.internal.annotations.VisibleForTesting; +import androidx.annotation.VisibleForTesting; import java.io.IOException; import java.net.Inet4Address; @@ -29,7 +29,7 @@ import java.util.Locale; import java.util.Objects; /** An mDNS "AAAA" or "A" record, which holds an IPv6 or IPv4 address. */ -@VisibleForTesting +@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE) public class MdnsInetAddressRecord extends MdnsRecord { @Nullable private Inet6Address inet6Address; @Nullable private Inet4Address inet4Address; diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java index 37e9743f09178bab728b43619e740a96bb62b488..e07d380cc67728f44ed75604b588c55b81804558 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java @@ -150,8 +150,8 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand /** @see MdnsRecordRepository */ @NonNull public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper, - @NonNull String[] deviceHostName) { - return new MdnsRecordRepository(looper, deviceHostName); + @NonNull String[] deviceHostName, @NonNull MdnsFeatureFlags mdnsFeatureFlags) { + return new MdnsRecordRepository(looper, deviceHostName, mdnsFeatureFlags); } /** @see MdnsReplySender */ @@ -187,16 +187,18 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb, - @NonNull String[] deviceHostName, @NonNull SharedLog sharedLog) { + @NonNull String[] deviceHostName, @NonNull SharedLog sharedLog, + @NonNull MdnsFeatureFlags mdnsFeatureFlags) { this(socket, initialAddresses, looper, packetCreationBuffer, cb, - new Dependencies(), deviceHostName, sharedLog); + new Dependencies(), deviceHostName, sharedLog, mdnsFeatureFlags); } public MdnsInterfaceAdvertiser(@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses, @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb, @NonNull Dependencies deps, - @NonNull String[] deviceHostName, @NonNull SharedLog sharedLog) { - mRecordRepository = deps.makeRecordRepository(looper, deviceHostName); + @NonNull String[] deviceHostName, @NonNull SharedLog sharedLog, + @NonNull MdnsFeatureFlags mdnsFeatureFlags) { + mRecordRepository = deps.makeRecordRepository(looper, deviceHostName, mdnsFeatureFlags); mRecordRepository.updateAddresses(initialAddresses); mSocket = socket; mCb = cb; diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsPointerRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsPointerRecord.java index c88ead09e570bd1260191a236c0b99d2caf09c2e..41cc3800299445fe2a620558e83e0dd65772b713 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsPointerRecord.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsPointerRecord.java @@ -18,14 +18,15 @@ package com.android.server.connectivity.mdns; import android.annotation.Nullable; -import com.android.internal.annotations.VisibleForTesting; +import androidx.annotation.VisibleForTesting; + import com.android.server.connectivity.mdns.util.MdnsUtils; import java.io.IOException; import java.util.Arrays; /** An mDNS "PTR" record, which holds a name (the "pointer"). */ -@VisibleForTesting +@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE) public class MdnsPointerRecord extends MdnsRecord { private String[] pointer; diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java index 130ff485d1528dca05478668d9ff0c565c2ea270..73c17583b42c76f4f19fc8540fb569758b1f7a50 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java @@ -92,16 +92,19 @@ public class MdnsRecordRepository { private final Looper mLooper; @NonNull private final String[] mDeviceHostname; + private final MdnsFeatureFlags mMdnsFeatureFlags; - public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname) { - this(looper, new Dependencies(), deviceHostname); + public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname, + @NonNull MdnsFeatureFlags mdnsFeatureFlags) { + this(looper, new Dependencies(), deviceHostname, mdnsFeatureFlags); } @VisibleForTesting public MdnsRecordRepository(@NonNull Looper looper, @NonNull Dependencies deps, - @NonNull String[] deviceHostname) { + @NonNull String[] deviceHostname, @NonNull MdnsFeatureFlags mdnsFeatureFlags) { mDeviceHostname = deviceHostname; mLooper = looper; + mMdnsFeatureFlags = mdnsFeatureFlags; } /** @@ -351,7 +354,8 @@ public class MdnsRecordRepository { } private MdnsProber.ProbingInfo makeProbingInfo(int serviceId, - @NonNull MdnsServiceRecord srvRecord) { + @NonNull MdnsServiceRecord srvRecord, + @NonNull List<MdnsInetAddressRecord> inetAddressRecords) { final List<MdnsRecord> probingRecords = new ArrayList<>(); // Probe with cacheFlush cleared; it is set when announcing, as it was verified unique: // RFC6762 10.2 @@ -363,6 +367,15 @@ public class MdnsRecordRepository { srvRecord.getServicePort(), srvRecord.getServiceHost())); + for (MdnsInetAddressRecord inetAddressRecord : inetAddressRecords) { + probingRecords.add(new MdnsInetAddressRecord(inetAddressRecord.getName(), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + inetAddressRecord.getTtl(), + inetAddressRecord.getInet4Address() == null + ? inetAddressRecord.getInet6Address() + : inetAddressRecord.getInet4Address())); + } return new MdnsProber.ProbingInfo(serviceId, probingRecords); } @@ -824,6 +837,18 @@ public class MdnsRecordRepository { return conflicting; } + private List<MdnsInetAddressRecord> makeProbingInetAddressRecords() { + final List<MdnsInetAddressRecord> records = new ArrayList<>(); + if (mMdnsFeatureFlags.mIncludeInetAddressRecordsInProbing) { + for (RecordInfo<?> record : mGeneralRecords) { + if (record.record instanceof MdnsInetAddressRecord) { + records.add((MdnsInetAddressRecord) record.record); + } + } + } + return records; + } + /** * (Re)set a service to the probing state. * @return The {@link MdnsProber.ProbingInfo} to send for probing. @@ -834,7 +859,8 @@ public class MdnsRecordRepository { if (registration == null) return null; registration.setProbing(true); - return makeProbingInfo(serviceId, registration.srvRecord.record); + return makeProbingInfo( + serviceId, registration.srvRecord.record, makeProbingInetAddressRecords()); } /** @@ -870,7 +896,8 @@ public class MdnsRecordRepository { final ServiceRegistration newService = new ServiceRegistration(mDeviceHostname, newInfo, existing.subtype, existing.repliedServiceCount, existing.sentPacketCount); mServices.put(serviceId, newService); - return makeProbingInfo(serviceId, newService.srvRecord.record); + return makeProbingInfo( + serviceId, newService.srvRecord.record, makeProbingInetAddressRecords()); } /** diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceRecord.java index f851b355050102356fb42addc66946fb68f9416e..4d407be63f784e3db7bb9e5314426bbaf7c22d84 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceRecord.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceRecord.java @@ -18,7 +18,8 @@ package com.android.server.connectivity.mdns; import android.annotation.Nullable; -import com.android.internal.annotations.VisibleForTesting; +import androidx.annotation.VisibleForTesting; + import com.android.server.connectivity.mdns.util.MdnsUtils; import java.io.IOException; @@ -27,7 +28,7 @@ import java.util.Locale; import java.util.Objects; /** An mDNS "SRV" record, which contains service information. */ -@VisibleForTesting +@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE) public class MdnsServiceRecord extends MdnsRecord { public static final int PROTO_NONE = 0; public static final int PROTO_TCP = 1; diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsTextRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsTextRecord.java index 4149dbee599692b0ce48e479ddc2c42152d24d1c..cf6c8acf2f07436bd71ac342348142e8fd0ab01e 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsTextRecord.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsTextRecord.java @@ -18,7 +18,8 @@ package com.android.server.connectivity.mdns; import android.annotation.Nullable; -import com.android.internal.annotations.VisibleForTesting; +import androidx.annotation.VisibleForTesting; + import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry; import java.io.IOException; @@ -28,7 +29,7 @@ import java.util.List; import java.util.Objects; /** An mDNS "TXT" record, which contains a list of {@link TextEntry}. */ -@VisibleForTesting +@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE) public class MdnsTextRecord extends MdnsRecord { private List<TextEntry> entries; diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt index 8eace1ccd0fa2484d62f1a3319484d15a5bd69fb..a86f923e0d85a21f1e1244c37938a5c9f9670a95 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt @@ -153,10 +153,10 @@ class MdnsAdvertiserTest { thread.start() doReturn(TEST_HOSTNAME).`when`(mockDeps).generateHostname() doReturn(mockInterfaceAdvertiser1).`when`(mockDeps).makeAdvertiser(eq(mockSocket1), - any(), any(), any(), any(), eq(TEST_HOSTNAME), any() + any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any() ) doReturn(mockInterfaceAdvertiser2).`when`(mockDeps).makeAdvertiser(eq(mockSocket2), - any(), any(), any(), any(), eq(TEST_HOSTNAME), any() + any(), any(), any(), any(), eq(TEST_HOSTNAME), any(), any() ) doReturn(true).`when`(mockInterfaceAdvertiser1).isProbing(anyInt()) doReturn(true).`when`(mockInterfaceAdvertiser2).isProbing(anyInt()) @@ -202,6 +202,7 @@ class MdnsAdvertiserTest { any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), + any(), any() ) @@ -259,10 +260,10 @@ class MdnsAdvertiserTest { val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), - eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any() + eq(thread.looper), any(), intAdvCbCaptor1.capture(), eq(TEST_HOSTNAME), any(), any() ) verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)), - eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any() + eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any(), any() ) verify(mockInterfaceAdvertiser1).addService( anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE)) @@ -367,7 +368,7 @@ class MdnsAdvertiserTest { val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)), - eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any() + eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any(), any() ) verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) }, eq(null)) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt index a67dc5e9a95dee0118085d1fda85a3ea8708a506..db41a6ad515c053de0d73083c1fcdd1cfd884b95 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt @@ -77,6 +77,7 @@ class MdnsInterfaceAdvertiserTest { private val announcer = mock(MdnsAnnouncer::class.java) private val prober = mock(MdnsProber::class.java) private val sharedlog = SharedLog("MdnsInterfaceAdvertiserTest") + private val flags = MdnsFeatureFlags.newBuilder().build() @Suppress("UNCHECKED_CAST") private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java) as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>> @@ -99,15 +100,14 @@ class MdnsInterfaceAdvertiserTest { cb, deps, TEST_HOSTNAME, - sharedlog + sharedlog, + flags ) } @Before fun setUp() { - doReturn(repository).`when`(deps).makeRecordRepository(any(), - eq(TEST_HOSTNAME) - ) + doReturn(repository).`when`(deps).makeRecordRepository(any(), eq(TEST_HOSTNAME), any()) doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any(), any()) doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any(), any()) doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any(), any()) diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt index c9b502e35208fe880dc7565876bf2a1af07cb632..f26f7e178d9b5a8ffcca76f94cd775cc9a97753f 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt @@ -78,6 +78,7 @@ class MdnsRecordRepositoryTest { override fun getInterfaceInetAddresses(iface: NetworkInterface) = Collections.enumeration(TEST_ADDRESSES.map { it.address }) } + private val flags = MdnsFeatureFlags.newBuilder().build() @Before fun setUp() { @@ -92,7 +93,7 @@ class MdnsRecordRepositoryTest { @Test fun testAddServiceAndProbe() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) assertEquals(0, repository.servicesCount) assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)) @@ -127,7 +128,7 @@ class MdnsRecordRepositoryTest { @Test fun testAddAndConflicts() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) assertFailsWith(NameConflictException::class) { repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */) @@ -139,7 +140,7 @@ class MdnsRecordRepositoryTest { @Test fun testInvalidReuseOfServiceId() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) assertFailsWith(IllegalArgumentException::class) { repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */) @@ -148,7 +149,7 @@ class MdnsRecordRepositoryTest { @Test fun testHasActiveService() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1)) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) @@ -165,7 +166,7 @@ class MdnsRecordRepositoryTest { @Test fun testExitAnnouncements() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) @@ -195,7 +196,7 @@ class MdnsRecordRepositoryTest { @Test fun testExitAnnouncements_WithSubtype() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) @@ -231,7 +232,7 @@ class MdnsRecordRepositoryTest { @Test fun testExitingServiceReAdded() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) repository.exitService(TEST_SERVICE_ID_1) @@ -246,7 +247,7 @@ class MdnsRecordRepositoryTest { @Test fun testOnProbingSucceeded() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE) repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */) @@ -371,7 +372,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetOffloadPacket() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") val serviceType = arrayOf("_testservice", "_tcp", "local") @@ -433,7 +434,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetReplyCaseInsensitive() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) val questionsCaseInSensitive = listOf(MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"), @@ -463,7 +464,7 @@ class MdnsRecordRepositoryTest { } private fun doGetReplyTest(subtype: String?) { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype) val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local") else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local") @@ -551,7 +552,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServices() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) @@ -579,7 +580,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServicesCaseInsensitive() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) @@ -607,7 +608,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServices_IdenticalService() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) @@ -636,7 +637,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetConflictingServicesCaseInsensitive_IdenticalService() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */) @@ -665,7 +666,7 @@ class MdnsRecordRepositoryTest { @Test fun testGetServiceRepliedRequestsCount() { - val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME) + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) // Verify that there is no packet replied. assertEquals(MdnsConstants.NO_PACKET, @@ -690,6 +691,68 @@ class MdnsRecordRepositoryTest { assertEquals(MdnsConstants.NO_PACKET, repository.getServiceRepliedRequestsCount(TEST_SERVICE_ID_2)) } + + @Test + fun testIncludeInetAddressRecordsInProbing() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, + MdnsFeatureFlags.newBuilder().setIncludeInetAddressRecordsInProbing(true).build()) + repository.updateAddresses(TEST_ADDRESSES) + assertEquals(0, repository.servicesCount) + assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, + null /* subtype */)) + assertEquals(1, repository.servicesCount) + + val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1) + assertNotNull(probingInfo) + assertTrue(repository.isProbing(TEST_SERVICE_ID_1)) + + assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId) + val packet = probingInfo.getPacket(0) + + assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags) + assertEquals(0, packet.answers.size) + assertEquals(0, packet.additionalRecords.size) + + assertEquals(2, packet.questions.size) + val expectedName = arrayOf("MyTestService", "_testservice", "_tcp", "local") + assertContentEquals(listOf( + MdnsAnyRecord(expectedName, false /* unicast */), + MdnsAnyRecord(TEST_HOSTNAME, false /* unicast */), + ), packet.questions) + + assertEquals(4, packet.authorityRecords.size) + assertContentEquals(listOf( + MdnsServiceRecord( + expectedName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + 120_000L /* ttlMillis */, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT, + TEST_HOSTNAME), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + 120_000L /* ttlMillis */, + TEST_ADDRESSES[0].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + 120_000L /* ttlMillis */, + TEST_ADDRESSES[1].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + 120_000L /* ttlMillis */, + TEST_ADDRESSES[2].address) + ), packet.authorityRecords) + + assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices()) + } } private fun MdnsRecordRepository.initWithService(