From 1c2be6b391ef9d4740e00db77c4447a242a4b4a6 Mon Sep 17 00:00:00 2001
From: Handa Wang <handaw@google.com>
Date: Fri, 29 Dec 2023 14:53:58 +0800
Subject: [PATCH] [mdns] add API support for custom hostname

This commit adds support of regsitering a service with a custom
hostname. This is required to enable Advertising Proxy feature for Thread
devices.

For example:
```
NsdServiceInfo info = new NsdServiceInfo();

info.setServiceName("My Service");
info.setServiceType("_test._tcp");
info.setHostname("MyHost");
info.setHostAddresses(List.of(address1, address2));

nsdManager.registerService(info, PROTOCOL_DNS_SD, listener);
```

This CL also adds two new error codes for service/host registration
conflict: `FAILURE_SERVICE_NAME_CONFLICT` and
`FAILURE_HOST_NAME_CONFLICT`.

Bug: 284904661
Test: atest CtsNetTestCases FrameworksNetTests

Change-Id: I2fb974c38d851ba99144a18c7047084bbd61b938
---
 .../src/android/net/nsd/NsdManager.java       |  62 ++-
 .../src/android/net/nsd/NsdServiceInfo.java   |  46 +++
 .../android/net/nsd/NsdServiceInfoTest.java   |   3 +
 .../net/src/android/net/cts/NsdManagerTest.kt | 370 ++++++++++++++++++
 .../java/android/net/nsd/NsdManagerTest.java  |  45 +++
 5 files changed, 517 insertions(+), 9 deletions(-)

diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java
index 263acf27c9..27b4955fa9 100644
--- a/framework-t/src/android/net/nsd/NsdManager.java
+++ b/framework-t/src/android/net/nsd/NsdManager.java
@@ -159,6 +159,8 @@ public final class NsdManager {
                 "com.android.net.flags.nsd_subtypes_support_enabled";
         static final String ADVERTISE_REQUEST_API =
                 "com.android.net.flags.advertise_request_api";
+        static final String NSD_CUSTOM_HOSTNAME_ENABLED =
+                "com.android.net.flags.nsd_custom_hostname_enabled";
     }
 
     /**
@@ -1237,7 +1239,7 @@ public final class NsdManager {
      */
     public void registerService(@NonNull NsdServiceInfo serviceInfo, int protocolType,
             @NonNull Executor executor, @NonNull RegistrationListener listener) {
-        checkServiceInfo(serviceInfo);
+        checkServiceInfoForRegistration(serviceInfo);
         checkProtocol(protocolType);
         final AdvertisingRequest.Builder builder = new AdvertisingRequest.Builder(serviceInfo,
                 protocolType);
@@ -1296,7 +1298,10 @@ public final class NsdManager {
      * @return Type and comma-separated list of subtypes, or null if invalid format.
      */
     @Nullable
-    private static Pair<String, String> getTypeAndSubtypes(@NonNull String typeWithSubtype) {
+    private static Pair<String, String> getTypeAndSubtypes(@Nullable String typeWithSubtype) {
+        if (typeWithSubtype == null) {
+            return null;
+        }
         final Matcher matcher = Pattern.compile(TYPE_REGEX).matcher(typeWithSubtype);
         if (!matcher.matches()) return null;
         // Reject specifications using leading subtypes with a dot
@@ -1327,10 +1332,7 @@ public final class NsdManager {
             @NonNull RegistrationListener listener) {
         final NsdServiceInfo serviceInfo = advertisingRequest.getServiceInfo();
         final int protocolType = advertisingRequest.getProtocolType();
-        if (serviceInfo.getPort() <= 0) {
-            throw new IllegalArgumentException("Invalid port number");
-        }
-        checkServiceInfo(serviceInfo);
+        checkServiceInfoForRegistration(serviceInfo);
         checkProtocol(protocolType);
         final int key;
         // For update only request, the old listener has to be reused
@@ -1607,7 +1609,7 @@ public final class NsdManager {
     @Deprecated
     public void resolveService(@NonNull NsdServiceInfo serviceInfo,
             @NonNull Executor executor, @NonNull ResolveListener listener) {
-        checkServiceInfo(serviceInfo);
+        checkServiceInfoForResolution(serviceInfo);
         int key = putListener(listener, executor, serviceInfo);
         try {
             mService.resolveService(key, serviceInfo);
@@ -1661,7 +1663,7 @@ public final class NsdManager {
     // TODO: use {@link DiscoveryRequest} to specify the service to be subscribed
     public void registerServiceInfoCallback(@NonNull NsdServiceInfo serviceInfo,
             @NonNull Executor executor, @NonNull ServiceInfoCallback listener) {
-        checkServiceInfo(serviceInfo);
+        checkServiceInfoForResolution(serviceInfo);
         int key = putListener(listener, executor, serviceInfo);
         try {
             mService.registerServiceInfoCallback(key, serviceInfo);
@@ -1706,7 +1708,7 @@ public final class NsdManager {
         }
     }
 
-    private static void checkServiceInfo(NsdServiceInfo serviceInfo) {
+    private static void checkServiceInfoForResolution(NsdServiceInfo serviceInfo) {
         Objects.requireNonNull(serviceInfo, "NsdServiceInfo cannot be null");
         if (TextUtils.isEmpty(serviceInfo.getServiceName())) {
             throw new IllegalArgumentException("Service name cannot be empty");
@@ -1715,4 +1717,46 @@ public final class NsdManager {
             throw new IllegalArgumentException("Service type cannot be empty");
         }
     }
+
+    /**
+     * Check if the {@link NsdServiceInfo} is valid for registration.
+     *
+     * The following can be registered:
+     * - A service with an optional host.
+     * - A hostname with addresses.
+     *
+     * Note that:
+     * - When registering a service, the service name, service type and port must be specified. If
+     *   hostname is specified, the host addresses can optionally be specified.
+     * - When registering a host without a service, the addresses must be specified.
+     *
+     * @hide
+     */
+    public static void checkServiceInfoForRegistration(NsdServiceInfo serviceInfo) {
+        Objects.requireNonNull(serviceInfo, "NsdServiceInfo cannot be null");
+        boolean hasServiceName = !TextUtils.isEmpty(serviceInfo.getServiceName());
+        boolean hasServiceType = !TextUtils.isEmpty(serviceInfo.getServiceType());
+        boolean hasHostname = !TextUtils.isEmpty(serviceInfo.getHostname());
+        boolean hasHostAddresses = !CollectionUtils.isEmpty(serviceInfo.getHostAddresses());
+
+        if (serviceInfo.getPort() < 0) {
+            throw new IllegalArgumentException("Invalid port");
+        }
+
+        if (hasServiceType || hasServiceName || (serviceInfo.getPort() > 0)) {
+            if (!(hasServiceType && hasServiceName && (serviceInfo.getPort() > 0))) {
+                throw new IllegalArgumentException(
+                        "The service type, service name or port is missing");
+            }
+        }
+
+        if (!hasServiceType && !hasHostname) {
+            throw new IllegalArgumentException("No service or host specified in NsdServiceInfo");
+        }
+
+        if (!hasServiceType && hasHostname && !hasHostAddresses) {
+            // TODO: b/317946010 - This may be allowed when it supports registering KEY RR.
+            throw new IllegalArgumentException("No host addresses specified in NsdServiceInfo");
+        }
+    }
 }
diff --git a/framework-t/src/android/net/nsd/NsdServiceInfo.java b/framework-t/src/android/net/nsd/NsdServiceInfo.java
index ac4ea2318e..146d4cae30 100644
--- a/framework-t/src/android/net/nsd/NsdServiceInfo.java
+++ b/framework-t/src/android/net/nsd/NsdServiceInfo.java
@@ -49,8 +49,10 @@ public final class NsdServiceInfo implements Parcelable {
 
     private static final String TAG = "NsdServiceInfo";
 
+    @Nullable
     private String mServiceName;
 
+    @Nullable
     private String mServiceType;
 
     private final Set<String> mSubtypes;
@@ -59,6 +61,9 @@ public final class NsdServiceInfo implements Parcelable {
 
     private final List<InetAddress> mHostAddresses;
 
+    @Nullable
+    private String mHostname;
+
     private int mPort;
 
     @Nullable
@@ -90,6 +95,7 @@ public final class NsdServiceInfo implements Parcelable {
         mSubtypes = new ArraySet<>(other.getSubtypes());
         mTxtRecord = new ArrayMap<>(other.mTxtRecord);
         mHostAddresses = new ArrayList<>(other.getHostAddresses());
+        mHostname = other.getHostname();
         mPort = other.getPort();
         mNetwork = other.getNetwork();
         mInterfaceIndex = other.getInterfaceIndex();
@@ -168,6 +174,43 @@ public final class NsdServiceInfo implements Parcelable {
         mHostAddresses.addAll(addresses);
     }
 
+    /**
+     * Get the hostname.
+     *
+     * <p>When a service is resolved, it returns the hostname of the resolved service . The top
+     * level domain ".local." is omitted.
+     *
+     * <p>For example, it returns "MyHost" when the service's hostname is "MyHost.local.".
+     *
+     * @hide
+     */
+//    @FlaggedApi(NsdManager.Flags.NSD_CUSTOM_HOSTNAME_ENABLED)
+    @Nullable
+    public String getHostname() {
+        return mHostname;
+    }
+
+    /**
+     * Set a custom hostname for this service instance for registration.
+     *
+     * <p>A hostname must be in ".local." domain. The ".local." must be omitted when calling this
+     * method.
+     *
+     * <p>For example, you should call setHostname("MyHost") to use the hostname "MyHost.local.".
+     *
+     * <p>If a hostname is set with this method, the addresses set with {@link #setHostAddresses}
+     * will be registered with the hostname.
+     *
+     * <p>If the hostname is null (which is the default for a new {@link NsdServiceInfo}), a random
+     * hostname is used and the addresses of this device will be registered.
+     *
+     * @hide
+     */
+//    @FlaggedApi(NsdManager.Flags.NSD_CUSTOM_HOSTNAME_ENABLED)
+    public void setHostname(@Nullable String hostname) {
+        mHostname = hostname;
+    }
+
     /**
      * Unpack txt information from a base-64 encoded byte array.
      *
@@ -454,6 +497,7 @@ public final class NsdServiceInfo implements Parcelable {
                 .append(", type: ").append(mServiceType)
                 .append(", subtypes: ").append(TextUtils.join(", ", mSubtypes))
                 .append(", hostAddresses: ").append(TextUtils.join(", ", mHostAddresses))
+                .append(", hostname: ").append(mHostname)
                 .append(", port: ").append(mPort)
                 .append(", network: ").append(mNetwork);
 
@@ -494,6 +538,7 @@ public final class NsdServiceInfo implements Parcelable {
         for (InetAddress address : mHostAddresses) {
             InetAddressUtils.parcelInetAddress(dest, address, flags);
         }
+        dest.writeString(mHostname);
     }
 
     /** Implement the Parcelable interface */
@@ -523,6 +568,7 @@ public final class NsdServiceInfo implements Parcelable {
                 for (int i = 0; i < size; i++) {
                     info.mHostAddresses.add(InetAddressUtils.unparcelInetAddress(in));
                 }
+                info.mHostname = in.readString();
                 return info;
             }
 
diff --git a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
index 79c4980a5c..8e89037724 100644
--- a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
+++ b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
@@ -119,6 +119,7 @@ public class NsdServiceInfoTest {
         fullInfo.setSubtypes(Set.of("_thread", "_matter"));
         fullInfo.setPort(4242);
         fullInfo.setHostAddresses(List.of(IPV4_ADDRESS));
+        fullInfo.setHostname("home");
         fullInfo.setNetwork(new Network(123));
         fullInfo.setInterfaceIndex(456);
         checkParcelable(fullInfo);
@@ -134,6 +135,7 @@ public class NsdServiceInfoTest {
         attributedInfo.setServiceType("_kitten._tcp");
         attributedInfo.setPort(4242);
         attributedInfo.setHostAddresses(List.of(IPV6_ADDRESS, IPV4_ADDRESS));
+        attributedInfo.setHostname("home");
         attributedInfo.setAttribute("color", "pink");
         attributedInfo.setAttribute("sound", (new String("にゃあ")).getBytes("UTF-8"));
         attributedInfo.setAttribute("adorable", (String) null);
@@ -169,6 +171,7 @@ public class NsdServiceInfoTest {
         assertEquals(original.getServiceName(), result.getServiceName());
         assertEquals(original.getServiceType(), result.getServiceType());
         assertEquals(original.getHost(), result.getHost());
+        assertEquals(original.getHostname(), result.getHostname());
         assertTrue(original.getPort() == result.getPort());
         assertEquals(original.getNetwork(), result.getNetwork());
         assertEquals(original.getInterfaceIndex(), result.getInterfaceIndex());
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 8f9f8c7043..c368d5bc3e 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -127,6 +127,7 @@ import org.junit.Before
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
+import kotlin.test.assertNotEquals
 
 private const val TAG = "NsdManagerTest"
 private const val TIMEOUT_MS = 2000L
@@ -162,7 +163,11 @@ class NsdManagerTest {
     private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
     private val serviceName = "NsdTest%09d".format(Random().nextInt(1_000_000_000))
     private val serviceName2 = "NsdTest%09d".format(Random().nextInt(1_000_000_000))
+    private val serviceName3 = "NsdTest%09d".format(Random().nextInt(1_000_000_000))
     private val serviceType = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000))
+    private val serviceType2 = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000))
+    private val customHostname = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
+    private val customHostname2 = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000))
     private val handlerThread = HandlerThread(NsdManagerTest::class.java.simpleName)
     private val ctsNetUtils by lazy{ CtsNetUtils(context) }
 
@@ -1188,6 +1193,84 @@ class NsdManagerTest {
         }
     }
 
+    @Test
+    fun testRegisterServiceWithCustomHostAndAddresses_conflictDuringProbing_hostRenamed() {
+        val si = makeTestServiceInfo(testNetwork1.network).apply {
+            hostname = customHostname
+            hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.24"),
+                    parseNumericAddress("2001:db8::3"))
+        }
+
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
+                registrationRecord)
+
+        tryTest {
+            assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
+                    "Did not find a probe for the service")
+            packetReader.sendResponse(buildConflictingAnnouncementForCustomHost())
+
+            // Registration must use an updated hostname to avoid the conflict
+            val cb = registrationRecord.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS)
+            // Service name is not renamed because there's no conflict on the service name.
+            // TODO: b/283053491 - enable this check
+//            assertEquals(serviceName, cb.serviceInfo.serviceName)
+            val hostname = cb.serviceInfo.hostname ?: fail("Missing hostname")
+            hostname.let {
+                assertTrue("Unexpected registered hostname: $it",
+                        it.startsWith(customHostname) && it != customHostname)
+            }
+        } cleanupStep {
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        } cleanup {
+            packetReader.handler.post { packetReader.stop() }
+            handlerThread.waitForIdle(TIMEOUT_MS)
+        }
+    }
+
+    @Test
+    fun testRegisterServiceWithCustomHostNoAddresses_noConflictDuringProbing_notRenamed() {
+        val si = makeTestServiceInfo(testNetwork1.network).apply {
+            hostname = customHostname
+        }
+
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
+                registrationRecord)
+
+        tryTest {
+            assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
+                    "Did not find a probe for the service")
+            // Not a conflict because no record is registered for the hostname
+            packetReader.sendResponse(buildConflictingAnnouncementForCustomHost())
+
+            // Registration is not renamed because there's no conflict
+            val cb = registrationRecord.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS)
+            assertEquals(serviceName, cb.serviceInfo.serviceName)
+            assertEquals(customHostname, cb.serviceInfo.hostname)
+        } cleanupStep {
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        } cleanup {
+            packetReader.handler.post { packetReader.stop() }
+            handlerThread.waitForIdle(TIMEOUT_MS)
+        }
+    }
+
     @Test
     fun testRegisterWithConflictAfterProbing() {
         // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
@@ -1263,6 +1346,52 @@ class NsdManagerTest {
         }
     }
 
+    // TODO: b/322282952 - Add the test case that the hostname is renamed due to a conflict after
+    //  probing succeeded.
+
+    @Test
+    fun testRegisterServiceWithCustomHostNoAddresses_noConflictAfterProbing_notRenamed() {
+        val si = makeTestServiceInfo(testNetwork1.network).apply {
+            hostname = customHostname
+        }
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        val discoveryRecord = NsdDiscoveryRecord()
+        val registeredService = registerService(registrationRecord, si)
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        tryTest {
+            assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType),
+                    "No announcements sent after initial probing")
+
+            assertEquals(si.serviceName, registeredService.serviceName)
+            assertEquals(si.hostname, registeredService.hostname)
+
+            // Send a conflicting announcement
+            val conflictingAnnouncement = buildConflictingAnnouncementForCustomHost()
+            packetReader.sendResponse(conflictingAnnouncement)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, { it.run() }, discoveryRecord)
+
+            // The service is not renamed
+            discoveryRecord.waitForServiceDiscovered(si.serviceName, serviceType)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        } cleanupStep {
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        } cleanup {
+            packetReader.handler.post { packetReader.stop() }
+            handlerThread.waitForIdle(TIMEOUT_MS)
+        }
+    }
+
     // Test that even if only a PTR record is received as a reply when discovering, without the
     // SRV, TXT, address records as recommended (but not mandated) by RFC 6763 12, the service can
     // still be discovered.
@@ -1447,6 +1576,212 @@ class NsdManagerTest {
         return Inet6Address.getByAddress(addrBytes) as Inet6Address
     }
 
+    @Test
+    fun testAdvertisingAndDiscovery_servicesWithCustomHost_customHostAddressesFound() {
+        val hostAddresses1 = listOf(
+                parseNumericAddress("192.0.2.23"),
+                parseNumericAddress("2001:db8::1"),
+                parseNumericAddress("2001:db8::2"))
+        val hostAddresses2 = listOf(
+                parseNumericAddress("192.0.2.24"),
+                parseNumericAddress("2001:db8::3"))
+        val si1 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceName = serviceName
+            it.serviceType = serviceType
+            it.port = TEST_PORT
+            it.hostname = customHostname
+            it.hostAddresses = hostAddresses1
+        }
+        val si2 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceName = serviceName2
+            it.serviceType = serviceType
+            it.port = TEST_PORT + 1
+            it.hostname = customHostname2
+            it.hostAddresses = hostAddresses2
+        }
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+
+        val discoveryRecord1 = NsdDiscoveryRecord()
+        val discoveryRecord2 = NsdDiscoveryRecord()
+        tryTest {
+            registerService(registrationRecord1, si1)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord1)
+
+            val discoveredInfo = discoveryRecord1.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo = resolveService(discoveredInfo)
+
+            assertEquals(TEST_PORT, resolvedInfo.port)
+            assertEquals(si1.hostname, resolvedInfo.hostname)
+            assertAddressEquals(hostAddresses1, resolvedInfo.hostAddresses)
+
+            registerService(registrationRecord2, si2)
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord2)
+
+            val discoveredInfo2 = discoveryRecord2.waitForServiceDiscovered(
+                    serviceName2, serviceType, testNetwork1.network)
+            val resolvedInfo2 = resolveService(discoveredInfo2)
+
+            assertEquals(TEST_PORT + 1, resolvedInfo2.port)
+            assertEquals(si2.hostname, resolvedInfo2.hostname)
+            assertAddressEquals(hostAddresses2, resolvedInfo2.hostAddresses)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord1)
+            nsdManager.stopServiceDiscovery(discoveryRecord2)
+
+            discoveryRecord1.expectCallbackEventually<DiscoveryStopped>()
+            discoveryRecord2.expectCallbackEventually<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord1)
+            nsdManager.unregisterService(registrationRecord2)
+        }
+    }
+
+    @Test
+    fun testAdvertisingAndDiscovery_multipleRegistrationsForSameCustomHost_unionOfAddressesFound() {
+        val hostAddresses1 = listOf(
+                parseNumericAddress("192.0.2.23"),
+                parseNumericAddress("2001:db8::1"),
+                parseNumericAddress("2001:db8::2"))
+        val hostAddresses2 = listOf(
+                parseNumericAddress("192.0.2.24"),
+                parseNumericAddress("2001:db8::3"))
+        val hostAddresses3 = listOf(
+                parseNumericAddress("2001:db8::3"),
+                parseNumericAddress("2001:db8::5"))
+        val si1 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = hostAddresses1
+        }
+        val si2 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceName = serviceName
+            it.serviceType = serviceType
+            it.port = TEST_PORT
+            it.hostname = customHostname
+            it.hostAddresses = hostAddresses2
+        }
+        val si3 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceName = serviceName3
+            it.serviceType = serviceType
+            it.port = TEST_PORT + 1
+            it.hostname = customHostname
+            it.hostAddresses = hostAddresses3
+        }
+
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        val registrationRecord3 = NsdRegistrationRecord()
+
+        val discoveryRecord = NsdDiscoveryRecord()
+        tryTest {
+            registerService(registrationRecord1, si1)
+            registerService(registrationRecord2, si2)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord)
+
+            val discoveredInfo1 = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo1 = resolveService(discoveredInfo1)
+
+            assertEquals(TEST_PORT, resolvedInfo1.port)
+            assertEquals(si1.hostname, resolvedInfo1.hostname)
+            assertAddressEquals(
+                    hostAddresses1 + hostAddresses2,
+                    resolvedInfo1.hostAddresses)
+
+            registerService(registrationRecord3, si3)
+
+            val discoveredInfo2 = discoveryRecord.waitForServiceDiscovered(
+                    serviceName3, serviceType, testNetwork1.network)
+            val resolvedInfo2 = resolveService(discoveredInfo2)
+
+            assertEquals(TEST_PORT + 1, resolvedInfo2.port)
+            assertEquals(si2.hostname, resolvedInfo2.hostname)
+            assertAddressEquals(
+                    hostAddresses1 + hostAddresses2 + hostAddresses3,
+                    resolvedInfo2.hostAddresses)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+
+            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord1)
+            nsdManager.unregisterService(registrationRecord2)
+            nsdManager.unregisterService(registrationRecord3)
+        }
+    }
+
+    @Test
+    fun testAdvertisingAndDiscovery_servicesWithTheSameCustomHostAddressOmitted_addressesFound() {
+        val hostAddresses = listOf(
+                parseNumericAddress("192.0.2.23"),
+                parseNumericAddress("2001:db8::1"),
+                parseNumericAddress("2001:db8::2"))
+        val si1 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType
+            it.serviceName = serviceName
+            it.port = TEST_PORT
+            it.hostname = customHostname
+            it.hostAddresses = hostAddresses
+        }
+        val si2 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceType = serviceType
+            it.serviceName = serviceName2
+            it.port = TEST_PORT + 1
+            it.hostname = customHostname
+        }
+
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+
+        val discoveryRecord = NsdDiscoveryRecord()
+        tryTest {
+            registerService(registrationRecord1, si1)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord)
+
+            val discoveredInfo1 = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo1 = resolveService(discoveredInfo1)
+
+            assertEquals(serviceName, discoveredInfo1.serviceName)
+            assertEquals(TEST_PORT, resolvedInfo1.port)
+            assertEquals(si1.hostname, resolvedInfo1.hostname)
+            assertAddressEquals(hostAddresses, resolvedInfo1.hostAddresses)
+
+            registerService(registrationRecord2, si2)
+
+            val discoveredInfo2 = discoveryRecord.waitForServiceDiscovered(
+                    serviceName2, serviceType, testNetwork1.network)
+            val resolvedInfo2 = resolveService(discoveredInfo2)
+
+            assertEquals(serviceName2, discoveredInfo2.serviceName)
+            assertEquals(TEST_PORT + 1, resolvedInfo2.port)
+            assertEquals(si2.hostname, resolvedInfo2.hostname)
+            assertAddressEquals(hostAddresses, resolvedInfo2.hostAddresses)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord1)
+            nsdManager.unregisterService(registrationRecord2)
+        }
+    }
+
     private fun buildConflictingAnnouncement(): ByteBuffer {
         /*
         Generated with:
@@ -1463,6 +1798,22 @@ class NsdManagerTest {
         return buildMdnsPacket(mdnsPayload)
     }
 
+    private fun buildConflictingAnnouncementForCustomHost(): ByteBuffer {
+        /*
+        Generated with scapy:
+        raw(DNS(rd=0, qr=1, aa=1, qd = None, an =
+            DNSRR(rrname='NsdTestHost123456789.local', type=28, rclass=1, ttl=120,
+                    rdata='2001:db8::321')
+        )).hex()
+         */
+        val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000144e7364" +
+                "54657374486f7374313233343536373839056c6f63616c00001c000100000078001020010db80000" +
+                "00000000000000000321")
+        replaceCustomHostnameWithTestSuffix(mdnsPayload)
+
+        return buildMdnsPacket(mdnsPayload)
+    }
+
     /**
      * Replaces occurrences of "NsdTest123456789" and "_nmt123456789" in mDNS payload with the
      * actual random name and type that are used by the test.
@@ -1479,6 +1830,19 @@ class NsdManagerTest {
         replaceAll(packetBuffer, testPacketTypePrefix, encodedTypePrefix)
     }
 
+    /**
+     * Replaces occurrences of "NsdTestHost123456789" in mDNS payload with the
+     * actual random host name that are used by the test.
+     */
+    private fun replaceCustomHostnameWithTestSuffix(mdnsPayload: ByteArray) {
+        // Test custom hostnames have consistent length and are always ASCII
+        val testPacketName = "NsdTestHost123456789".encodeToByteArray()
+        val encodedHostname = customHostname.encodeToByteArray()
+
+        val packetBuffer = ByteBuffer.wrap(mdnsPayload)
+        replaceAll(packetBuffer, testPacketName, encodedHostname)
+    }
+
     private tailrec fun replaceAll(buffer: ByteBuffer, source: ByteArray, replacement: ByteArray) {
         assertEquals(source.size, replacement.size)
         val index = buffer.array().indexOf(source)
@@ -1577,3 +1941,9 @@ private fun ByteArray?.utf8ToString(): String {
     if (this == null) return ""
     return String(this, StandardCharsets.UTF_8)
 }
+
+private fun assertAddressEquals(expected: List<InetAddress>, actual: List<InetAddress>) {
+    // No duplicate addresses in the actual address list
+    assertEquals(actual.toSet().size, actual.size)
+    assertEquals(expected.toSet(), actual.toSet())
+}
\ No newline at end of file
diff --git a/tests/unit/java/android/net/nsd/NsdManagerTest.java b/tests/unit/java/android/net/nsd/NsdManagerTest.java
index aabe8d3423..951675ca79 100644
--- a/tests/unit/java/android/net/nsd/NsdManagerTest.java
+++ b/tests/unit/java/android/net/nsd/NsdManagerTest.java
@@ -52,6 +52,9 @@ import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import java.net.InetAddress;
+import java.util.List;
+
 @DevSdkIgnoreRunner.MonitorThreadLeak
 @RunWith(DevSdkIgnoreRunner.class)
 @SmallTest
@@ -370,6 +373,9 @@ public class NsdManagerTest {
         NsdManager.RegistrationListener listener1 = mock(NsdManager.RegistrationListener.class);
         NsdManager.DiscoveryListener listener2 = mock(NsdManager.DiscoveryListener.class);
         NsdManager.ResolveListener listener3 = mock(NsdManager.ResolveListener.class);
+        NsdManager.RegistrationListener listener4 = mock(NsdManager.RegistrationListener.class);
+        NsdManager.RegistrationListener listener5 = mock(NsdManager.RegistrationListener.class);
+        NsdManager.RegistrationListener listener6 = mock(NsdManager.RegistrationListener.class);
 
         NsdServiceInfo invalidService = new NsdServiceInfo(null, null);
         NsdServiceInfo validService = new NsdServiceInfo("a_name", "_a_type._tcp");
@@ -379,6 +385,7 @@ public class NsdManagerTest {
                 "_a_type._tcp,_sub1,_s2");
         NsdServiceInfo otherSubtypeUpdate = new NsdServiceInfo("a_name", "_a_type._tcp,_sub1,_s3");
         NsdServiceInfo dotSyntaxSubtypeUpdate = new NsdServiceInfo("a_name", "_sub1._a_type._tcp");
+
         validService.setPort(2222);
         otherServiceWithSubtype.setPort(2222);
         validServiceDuplicate.setPort(2222);
@@ -386,6 +393,33 @@ public class NsdManagerTest {
         otherSubtypeUpdate.setPort(2222);
         dotSyntaxSubtypeUpdate.setPort(2222);
 
+        NsdServiceInfo invalidMissingHostnameWithAddresses = new NsdServiceInfo(null, null);
+        invalidMissingHostnameWithAddresses.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("192.168.82.14"),
+                        InetAddress.parseNumericAddress("2001::1")));
+
+        NsdServiceInfo validCustomHostWithAddresses = new NsdServiceInfo(null, null);
+        validCustomHostWithAddresses.setHostname("a_host");
+        validCustomHostWithAddresses.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("192.168.82.14"),
+                        InetAddress.parseNumericAddress("2001::1")));
+
+        NsdServiceInfo validServiceWithCustomHostAndAddresses =
+                new NsdServiceInfo("a_name", "_a_type._tcp");
+        validServiceWithCustomHostAndAddresses.setPort(2222);
+        validServiceWithCustomHostAndAddresses.setHostname("a_host");
+        validServiceWithCustomHostAndAddresses.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("192.168.82.14"),
+                        InetAddress.parseNumericAddress("2001::1")));
+
+        NsdServiceInfo validServiceWithCustomHostNoAddresses =
+                new NsdServiceInfo("a_name", "_a_type._tcp");
+        validServiceWithCustomHostNoAddresses.setPort(2222);
+        validServiceWithCustomHostNoAddresses.setHostname("a_host");
+
         // Service registration
         //  - invalid arguments
         mustFail(() -> { manager.unregisterService(null); });
@@ -394,6 +428,8 @@ public class NsdManagerTest {
         mustFail(() -> { manager.registerService(invalidService, PROTOCOL, listener1); });
         mustFail(() -> { manager.registerService(validService, -1, listener1); });
         mustFail(() -> { manager.registerService(validService, PROTOCOL, null); });
+        mustFail(() -> {
+            manager.registerService(invalidMissingHostnameWithAddresses, PROTOCOL, listener1); });
         manager.registerService(validService, PROTOCOL, listener1);
         //  - update without subtype is not allowed
         mustFail(() -> { manager.registerService(validServiceDuplicate, PROTOCOL, listener1); });
@@ -415,6 +451,15 @@ public class NsdManagerTest {
         // TODO: make listener immediately reusable
         //mustFail(() -> { manager.unregisterService(listener1); });
         //manager.registerService(validService, PROTOCOL, listener1);
+        //  - registering a custom host without a service is valid
+        manager.registerService(validCustomHostWithAddresses, PROTOCOL, listener4);
+        manager.unregisterService(listener4);
+        //  - registering a service with a custom host is valid
+        manager.registerService(validServiceWithCustomHostAndAddresses, PROTOCOL, listener5);
+        manager.unregisterService(listener5);
+        //  - registering a service with a custom host with no addresses is valid
+        manager.registerService(validServiceWithCustomHostNoAddresses, PROTOCOL, listener6);
+        manager.unregisterService(listener6);
 
         // Discover service
         //  - invalid arguments
-- 
GitLab