diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java index 263acf27c99cd03c0bfe33c0a6125d2880a55a10..27b4955fa99021b802c0a526d64d1e6d5a0f6a59 100644 --- a/framework-t/src/android/net/nsd/NsdManager.java +++ b/framework-t/src/android/net/nsd/NsdManager.java @@ -159,6 +159,8 @@ public final class NsdManager { "com.android.net.flags.nsd_subtypes_support_enabled"; static final String ADVERTISE_REQUEST_API = "com.android.net.flags.advertise_request_api"; + static final String NSD_CUSTOM_HOSTNAME_ENABLED = + "com.android.net.flags.nsd_custom_hostname_enabled"; } /** @@ -1237,7 +1239,7 @@ public final class NsdManager { */ public void registerService(@NonNull NsdServiceInfo serviceInfo, int protocolType, @NonNull Executor executor, @NonNull RegistrationListener listener) { - checkServiceInfo(serviceInfo); + checkServiceInfoForRegistration(serviceInfo); checkProtocol(protocolType); final AdvertisingRequest.Builder builder = new AdvertisingRequest.Builder(serviceInfo, protocolType); @@ -1296,7 +1298,10 @@ public final class NsdManager { * @return Type and comma-separated list of subtypes, or null if invalid format. */ @Nullable - private static Pair<String, String> getTypeAndSubtypes(@NonNull String typeWithSubtype) { + private static Pair<String, String> getTypeAndSubtypes(@Nullable String typeWithSubtype) { + if (typeWithSubtype == null) { + return null; + } final Matcher matcher = Pattern.compile(TYPE_REGEX).matcher(typeWithSubtype); if (!matcher.matches()) return null; // Reject specifications using leading subtypes with a dot @@ -1327,10 +1332,7 @@ public final class NsdManager { @NonNull RegistrationListener listener) { final NsdServiceInfo serviceInfo = advertisingRequest.getServiceInfo(); final int protocolType = advertisingRequest.getProtocolType(); - if (serviceInfo.getPort() <= 0) { - throw new IllegalArgumentException("Invalid port number"); - } - checkServiceInfo(serviceInfo); + checkServiceInfoForRegistration(serviceInfo); checkProtocol(protocolType); final int key; // For update only request, the old listener has to be reused @@ -1607,7 +1609,7 @@ public final class NsdManager { @Deprecated public void resolveService(@NonNull NsdServiceInfo serviceInfo, @NonNull Executor executor, @NonNull ResolveListener listener) { - checkServiceInfo(serviceInfo); + checkServiceInfoForResolution(serviceInfo); int key = putListener(listener, executor, serviceInfo); try { mService.resolveService(key, serviceInfo); @@ -1661,7 +1663,7 @@ public final class NsdManager { // TODO: use {@link DiscoveryRequest} to specify the service to be subscribed public void registerServiceInfoCallback(@NonNull NsdServiceInfo serviceInfo, @NonNull Executor executor, @NonNull ServiceInfoCallback listener) { - checkServiceInfo(serviceInfo); + checkServiceInfoForResolution(serviceInfo); int key = putListener(listener, executor, serviceInfo); try { mService.registerServiceInfoCallback(key, serviceInfo); @@ -1706,7 +1708,7 @@ public final class NsdManager { } } - private static void checkServiceInfo(NsdServiceInfo serviceInfo) { + private static void checkServiceInfoForResolution(NsdServiceInfo serviceInfo) { Objects.requireNonNull(serviceInfo, "NsdServiceInfo cannot be null"); if (TextUtils.isEmpty(serviceInfo.getServiceName())) { throw new IllegalArgumentException("Service name cannot be empty"); @@ -1715,4 +1717,46 @@ public final class NsdManager { throw new IllegalArgumentException("Service type cannot be empty"); } } + + /** + * Check if the {@link NsdServiceInfo} is valid for registration. + * + * The following can be registered: + * - A service with an optional host. + * - A hostname with addresses. + * + * Note that: + * - When registering a service, the service name, service type and port must be specified. If + * hostname is specified, the host addresses can optionally be specified. + * - When registering a host without a service, the addresses must be specified. + * + * @hide + */ + public static void checkServiceInfoForRegistration(NsdServiceInfo serviceInfo) { + Objects.requireNonNull(serviceInfo, "NsdServiceInfo cannot be null"); + boolean hasServiceName = !TextUtils.isEmpty(serviceInfo.getServiceName()); + boolean hasServiceType = !TextUtils.isEmpty(serviceInfo.getServiceType()); + boolean hasHostname = !TextUtils.isEmpty(serviceInfo.getHostname()); + boolean hasHostAddresses = !CollectionUtils.isEmpty(serviceInfo.getHostAddresses()); + + if (serviceInfo.getPort() < 0) { + throw new IllegalArgumentException("Invalid port"); + } + + if (hasServiceType || hasServiceName || (serviceInfo.getPort() > 0)) { + if (!(hasServiceType && hasServiceName && (serviceInfo.getPort() > 0))) { + throw new IllegalArgumentException( + "The service type, service name or port is missing"); + } + } + + if (!hasServiceType && !hasHostname) { + throw new IllegalArgumentException("No service or host specified in NsdServiceInfo"); + } + + if (!hasServiceType && hasHostname && !hasHostAddresses) { + // TODO: b/317946010 - This may be allowed when it supports registering KEY RR. + throw new IllegalArgumentException("No host addresses specified in NsdServiceInfo"); + } + } } diff --git a/framework-t/src/android/net/nsd/NsdServiceInfo.java b/framework-t/src/android/net/nsd/NsdServiceInfo.java index ac4ea2318ee01f2f13353d55c628b801b6f2137d..146d4cae30327188a921a0bde601e2d0ef9852a8 100644 --- a/framework-t/src/android/net/nsd/NsdServiceInfo.java +++ b/framework-t/src/android/net/nsd/NsdServiceInfo.java @@ -49,8 +49,10 @@ public final class NsdServiceInfo implements Parcelable { private static final String TAG = "NsdServiceInfo"; + @Nullable private String mServiceName; + @Nullable private String mServiceType; private final Set<String> mSubtypes; @@ -59,6 +61,9 @@ public final class NsdServiceInfo implements Parcelable { private final List<InetAddress> mHostAddresses; + @Nullable + private String mHostname; + private int mPort; @Nullable @@ -90,6 +95,7 @@ public final class NsdServiceInfo implements Parcelable { mSubtypes = new ArraySet<>(other.getSubtypes()); mTxtRecord = new ArrayMap<>(other.mTxtRecord); mHostAddresses = new ArrayList<>(other.getHostAddresses()); + mHostname = other.getHostname(); mPort = other.getPort(); mNetwork = other.getNetwork(); mInterfaceIndex = other.getInterfaceIndex(); @@ -168,6 +174,43 @@ public final class NsdServiceInfo implements Parcelable { mHostAddresses.addAll(addresses); } + /** + * Get the hostname. + * + * <p>When a service is resolved, it returns the hostname of the resolved service . The top + * level domain ".local." is omitted. + * + * <p>For example, it returns "MyHost" when the service's hostname is "MyHost.local.". + * + * @hide + */ +// @FlaggedApi(NsdManager.Flags.NSD_CUSTOM_HOSTNAME_ENABLED) + @Nullable + public String getHostname() { + return mHostname; + } + + /** + * Set a custom hostname for this service instance for registration. + * + * <p>A hostname must be in ".local." domain. The ".local." must be omitted when calling this + * method. + * + * <p>For example, you should call setHostname("MyHost") to use the hostname "MyHost.local.". + * + * <p>If a hostname is set with this method, the addresses set with {@link #setHostAddresses} + * will be registered with the hostname. + * + * <p>If the hostname is null (which is the default for a new {@link NsdServiceInfo}), a random + * hostname is used and the addresses of this device will be registered. + * + * @hide + */ +// @FlaggedApi(NsdManager.Flags.NSD_CUSTOM_HOSTNAME_ENABLED) + public void setHostname(@Nullable String hostname) { + mHostname = hostname; + } + /** * Unpack txt information from a base-64 encoded byte array. * @@ -454,6 +497,7 @@ public final class NsdServiceInfo implements Parcelable { .append(", type: ").append(mServiceType) .append(", subtypes: ").append(TextUtils.join(", ", mSubtypes)) .append(", hostAddresses: ").append(TextUtils.join(", ", mHostAddresses)) + .append(", hostname: ").append(mHostname) .append(", port: ").append(mPort) .append(", network: ").append(mNetwork); @@ -494,6 +538,7 @@ public final class NsdServiceInfo implements Parcelable { for (InetAddress address : mHostAddresses) { InetAddressUtils.parcelInetAddress(dest, address, flags); } + dest.writeString(mHostname); } /** Implement the Parcelable interface */ @@ -523,6 +568,7 @@ public final class NsdServiceInfo implements Parcelable { for (int i = 0; i < size; i++) { info.mHostAddresses.add(InetAddressUtils.unparcelInetAddress(in)); } + info.mHostname = in.readString(); return info; } diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java index be0b67353a231f962a7500af697a545959abc8e9..397e5a63d75131739208584b91f9fee5327f920e 100644 --- a/service-t/src/com/android/server/NsdService.java +++ b/service-t/src/com/android/server/NsdService.java @@ -895,10 +895,18 @@ public class NsdService extends INsdManager.Stub { serviceType); final String registerServiceType = typeSubtype == null ? null : typeSubtype.first; + final String hostname = serviceInfo.getHostname(); + // Keep compatible with the legacy behavior: It's allowed to set host + // addresses for a service registration although the host addresses + // won't be registered. To register the addresses for a host, the + // hostname must be specified. + if (hostname == null) { + serviceInfo.setHostAddresses(Collections.emptyList()); + } if (clientInfo.mUseJavaBackend || mDeps.isMdnsAdvertiserEnabled(mContext) || useAdvertiserForType(registerServiceType)) { - if (registerServiceType == null) { + if (serviceType != null && registerServiceType == null) { Log.e(TAG, "Invalid service type: " + serviceType); clientInfo.onRegisterServiceFailedImmediately(clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR, false /* isLegacy */); @@ -921,14 +929,25 @@ public class NsdService extends INsdManager.Stub { } else { transactionId = getUniqueId(); } - serviceInfo.setServiceType(registerServiceType); - serviceInfo.setServiceName(truncateServiceName( - serviceInfo.getServiceName())); + + if (registerServiceType != null) { + serviceInfo.setServiceType(registerServiceType); + serviceInfo.setServiceName( + truncateServiceName(serviceInfo.getServiceName())); + } + + if (!checkHostname(hostname)) { + clientInfo.onRegisterServiceFailedImmediately(clientRequestId, + NsdManager.FAILURE_BAD_PARAMETERS, false /* isLegacy */); + break; + } Set<String> subtypes = new ArraySet<>(serviceInfo.getSubtypes()); - for (String subType: typeSubtype.second) { - if (!TextUtils.isEmpty(subType)) { - subtypes.add(subType); + if (typeSubtype != null && typeSubtype.second != null) { + for (String subType : typeSubtype.second) { + if (!TextUtils.isEmpty(subType)) { + subtypes.add(subType); + } } } subtypes = dedupSubtypeLabels(subtypes); @@ -945,7 +964,7 @@ public class NsdService extends INsdManager.Stub { MdnsAdvertisingOptions.newBuilder().setIsOnlyUpdate( isUpdateOnly).build(); mAdvertiser.addOrUpdateService(transactionId, serviceInfo, - mdnsAdvertisingOptions); + mdnsAdvertisingOptions, clientInfo.mUid); storeAdvertiserRequestMap(clientRequestId, transactionId, clientInfo, serviceInfo.getNetwork()); } else { @@ -1535,6 +1554,7 @@ public class NsdService extends INsdManager.Stub { Log.e(TAG, "Invalid attribute", e); } } + info.setHostname(getHostname(serviceInfo)); final List<InetAddress> addresses = getInetAddresses(serviceInfo); if (addresses.size() != 0) { info.setHostAddresses(addresses); @@ -1571,6 +1591,7 @@ public class NsdService extends INsdManager.Stub { } } + info.setHostname(getHostname(serviceInfo)); final List<InetAddress> addresses = getInetAddresses(serviceInfo); info.setHostAddresses(addresses); clientInfo.onServiceUpdated(clientRequestId, info, request); @@ -1617,6 +1638,16 @@ public class NsdService extends INsdManager.Stub { return addresses; } + @NonNull + private static String getHostname(@NonNull MdnsServiceInfo serviceInfo) { + String[] hostname = serviceInfo.getHostName(); + // Strip the "local" top-level domain. + if (hostname.length >= 2 && hostname[hostname.length - 1].equals("local")) { + hostname = Arrays.copyOf(hostname, hostname.length - 1); + } + return String.join(".", hostname); + } + private static void setServiceNetworkForCallback(NsdServiceInfo info, int netId, int ifaceIdx) { switch (netId) { case NETID_UNSET: @@ -1702,6 +1733,21 @@ public class NsdService extends INsdManager.Stub { return new Pair<>(queryType, Collections.emptyList()); } + /** + * Checks if the hostname is valid. + * + * <p>For now NsdService only allows single-label hostnames conforming to RFC 1035. In other + * words, the hostname should be at most 63 characters long and it only contains letters, digits + * and hyphens. + */ + public static boolean checkHostname(@Nullable String hostname) { + if (hostname == null) { + return true; + } + String HOSTNAME_REGEX = "^[a-zA-Z]([a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$"; + return Pattern.compile(HOSTNAME_REGEX).matcher(hostname).matches(); + } + /** Returns {@code true} if {@code subtype} is a valid DNS-SD subtype label. */ private static boolean checkSubtypeLabel(String subtype) { return Pattern.compile("^" + TYPE_SUBTYPE_LABEL_REGEX + "$").matcher(subtype).matches(); @@ -2031,9 +2077,10 @@ public class NsdService extends INsdManager.Stub { final int clientRequestId = getClientRequestIdOrLog(clientInfo, transactionId); if (clientRequestId < 0) return; - // onRegisterServiceSucceeded only has the service name in its info. This aligns with - // historical behavior. + // onRegisterServiceSucceeded only has the service name and hostname in its info. This + // aligns with historical behavior. final NsdServiceInfo cbInfo = new NsdServiceInfo(registeredInfo.getServiceName(), null); + cbInfo.setHostname(registeredInfo.getHostname()); final ClientRequest request = clientInfo.mClientRequests.get(clientRequestId); clientInfo.onRegisterServiceSucceeded(clientRequestId, cbInfo, request); } @@ -2143,6 +2190,7 @@ public class NsdService extends INsdManager.Stub { @Override public void registerService(int listenerKey, AdvertisingRequest advertisingRequest) throws RemoteException { + NsdManager.checkServiceInfoForRegistration(advertisingRequest.getServiceInfo()); mNsdStateMachine.sendMessage(mNsdStateMachine.obtainMessage( NsdManager.REGISTER_SERVICE, 0, listenerKey, new AdvertisingArgs(this, advertisingRequest) diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java index b2af93cdb45b6f225c67299bfb2d53413c746c9e..60859f8c106204d6c25e4414c9e4a419f55135f2 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java @@ -17,6 +17,8 @@ package com.android.server.connectivity.mdns; import static com.android.server.connectivity.mdns.MdnsConstants.NO_PACKET; +import static com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST; +import static com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE; import static com.android.server.connectivity.mdns.MdnsRecord.MAX_LABEL_LENGTH; import android.annotation.NonNull; @@ -31,6 +33,7 @@ import android.net.nsd.OffloadEngine; import android.net.nsd.OffloadServiceInfo; import android.os.Build; import android.os.Looper; +import android.text.TextUtils; import android.util.ArrayMap; import android.util.Log; import android.util.SparseArray; @@ -151,7 +154,9 @@ public class MdnsAdvertiser { mSharedLog.wtf("Register succeeded for unknown registration"); return; } - if (mMdnsFeatureFlags.mIsMdnsOffloadFeatureEnabled) { + if (mMdnsFeatureFlags.mIsMdnsOffloadFeatureEnabled + // TODO: Enable offload when the serviceInfo contains a custom host. + && TextUtils.isEmpty(registration.getServiceInfo().getHostname())) { final String interfaceName = advertiser.getSocketInterfaceName(); final List<OffloadServiceInfoWrapper> existingOffloadServiceInfoWrappers = mInterfaceOffloadServices.computeIfAbsent(interfaceName, @@ -179,8 +184,11 @@ public class MdnsAdvertiser { } @Override - public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) { - mSharedLog.i("Found conflict, restarted probing for service " + serviceId); + public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId, + int conflictType) { + mSharedLog.i("Found conflict, restarted probing for service " + + serviceId + " " + + conflictType); final Registration registration = mRegistrations.get(serviceId); if (registration == null) return; @@ -205,10 +213,22 @@ public class MdnsAdvertiser { return; } - // Conflict was found during probing; rename once to find a name that has no conflict - registration.updateForConflict( - registration.makeNewServiceInfoForConflict(1 /* renameCount */), - 1 /* renameCount */); + if ((conflictType & CONFLICT_SERVICE) != 0) { + // Service conflict was found during probing; rename once to find a name that has no + // conflict + registration.updateForServiceConflict( + registration.makeNewServiceInfoForServiceConflict(1 /* renameCount */), + 1 /* renameCount */); + } + + if ((conflictType & CONFLICT_HOST) != 0) { + // Host conflict was found during probing; rename once to find a name that has no + // conflict + registration.updateForHostConflict( + registration.makeNewServiceInfoForHostConflict(1 /* renameCount */), + 1 /* renameCount */); + } + registration.mConflictDuringProbingCount++; // Keep renaming if the new name conflicts in local registrations @@ -231,23 +251,53 @@ public class MdnsAdvertiser { } }; - private boolean hasAnyConflict( + private boolean hasAnyServiceConflict( @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter, @NonNull NsdServiceInfo newInfo) { - return any(mAdvertiserRequests, (network, adv) -> - applicableAdvertiserFilter.test(network, adv) && adv.hasConflict(newInfo)); + return any( + mAdvertiserRequests, + (network, adv) -> + applicableAdvertiserFilter.test(network, adv) + && adv.hasServiceConflict(newInfo)); + } + + private boolean hasAnyHostConflict( + @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter, + @NonNull NsdServiceInfo newInfo, + int clientUid) { + // Check if it conflicts with custom hosts. + if (any( + mAdvertiserRequests, + (network, adv) -> + applicableAdvertiserFilter.test(network, adv) + && adv.hasHostConflict(newInfo, clientUid))) { + return true; + } + // Check if it conflicts with the default hostname. + return MdnsUtils.equalsIgnoreDnsCase(newInfo.getHostname(), mDeviceHostName[0]); } private void updateRegistrationUntilNoConflict( @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter, @NonNull Registration registration) { - int renameCount = 0; NsdServiceInfo newInfo = registration.getServiceInfo(); - while (hasAnyConflict(applicableAdvertiserFilter, newInfo)) { - renameCount++; - newInfo = registration.makeNewServiceInfoForConflict(renameCount); + + int renameServiceCount = 0; + while (hasAnyServiceConflict(applicableAdvertiserFilter, newInfo)) { + renameServiceCount++; + newInfo = registration.makeNewServiceInfoForServiceConflict(renameServiceCount); + } + registration.updateForServiceConflict(newInfo, renameServiceCount); + + if (!TextUtils.isEmpty(registration.getServiceInfo().getHostname())) { + int renameHostCount = 0; + while (hasAnyHostConflict( + applicableAdvertiserFilter, newInfo, registration.mClientUid)) { + renameHostCount++; + newInfo = registration.makeNewServiceInfoForHostConflict(renameHostCount); + } + registration.updateForHostConflict(newInfo, renameHostCount); } - registration.updateForConflict(newInfo, renameCount); } private void maybeSendOffloadStop(final String interfaceName, int serviceId) { @@ -326,16 +376,27 @@ public class MdnsAdvertiser { /** * Return whether using the proposed new {@link NsdServiceInfo} to add a registration would - * cause a conflict in this {@link InterfaceAdvertiserRequest}. + * cause a conflict of the service in this {@link InterfaceAdvertiserRequest}. */ - boolean hasConflict(@NonNull NsdServiceInfo newInfo) { - return getConflictingService(newInfo) >= 0; + boolean hasServiceConflict(@NonNull NsdServiceInfo newInfo) { + return getConflictingRegistrationDueToService(newInfo) >= 0; } /** - * Get the ID of a conflicting service, or -1 if none. + * Return whether using the proposed new {@link NsdServiceInfo} to add a registration would + * cause a conflict of the host in this {@link InterfaceAdvertiserRequest}. + * + * @param clientUid UID of the user who wants to advertise the serviceInfo. */ - int getConflictingService(@NonNull NsdServiceInfo info) { + boolean hasHostConflict(@NonNull NsdServiceInfo newInfo, int clientUid) { + return getConflictingRegistrationDueToHost(newInfo, clientUid) >= 0; + } + + /** Get the ID of a conflicting registration due to service, or -1 if none. */ + int getConflictingRegistrationDueToService(@NonNull NsdServiceInfo info) { + if (TextUtils.isEmpty(info.getServiceName())) { + return -1; + } for (int i = 0; i < mPendingRegistrations.size(); i++) { final NsdServiceInfo other = mPendingRegistrations.valueAt(i).getServiceInfo(); if (MdnsUtils.equalsIgnoreDnsCase(info.getServiceName(), other.getServiceName()) @@ -347,10 +408,35 @@ public class MdnsAdvertiser { return -1; } + /** + * Get the ID of a conflicting registration due to host, or -1 if none. + * + * <p>It's valid that multiple registrations from the same user are using the same hostname. + * + * <p>If there's already another registration with the same hostname requested by another + * user, this is considered a conflict. + */ + int getConflictingRegistrationDueToHost(@NonNull NsdServiceInfo info, int clientUid) { + if (TextUtils.isEmpty(info.getHostname())) { + return -1; + } + for (int i = 0; i < mPendingRegistrations.size(); i++) { + final Registration otherRegistration = mPendingRegistrations.valueAt(i); + final NsdServiceInfo otherInfo = otherRegistration.getServiceInfo(); + if (clientUid != otherRegistration.mClientUid + && MdnsUtils.equalsIgnoreDnsCase( + info.getHostname(), otherInfo.getHostname())) { + return mPendingRegistrations.keyAt(i); + } + } + return -1; + } + /** * Add a service to advertise. * - * Conflicts must be checked via {@link #getConflictingService} before attempting to add. + * <p>Conflicts must be checked via {@link #getConflictingRegistrationDueToService} and + * {@link #getConflictingRegistrationDueToHost} before attempting to add. */ void addService(int id, @NonNull Registration registration) { mPendingRegistrations.put(id, registration); @@ -484,27 +570,35 @@ public class MdnsAdvertiser { } private static class Registration { - @NonNull - final String mOriginalName; + @Nullable + final String mOriginalServiceName; + @Nullable + final String mOriginalHostname; boolean mNotifiedRegistrationSuccess; - private int mConflictCount; + private int mServiceNameConflictCount; + private int mHostnameConflictCount; @NonNull private NsdServiceInfo mServiceInfo; + final int mClientUid; int mConflictDuringProbingCount; int mConflictAfterProbingCount; - private Registration(@NonNull NsdServiceInfo serviceInfo) { - this.mOriginalName = serviceInfo.getServiceName(); + + private Registration(@NonNull NsdServiceInfo serviceInfo, int clientUid) { + this.mOriginalServiceName = serviceInfo.getServiceName(); + this.mOriginalHostname = serviceInfo.getHostname(); this.mServiceInfo = serviceInfo; + this.mClientUid = clientUid; } - /** - * Matches between the NsdServiceInfo in the Registration and the provided argument. - */ - public boolean matches(@Nullable NsdServiceInfo newInfo) { - return Objects.equals(newInfo.getServiceName(), mOriginalName) && Objects.equals( - newInfo.getServiceType(), mServiceInfo.getServiceType()) && Objects.equals( - newInfo.getNetwork(), mServiceInfo.getNetwork()); + /** Check if the new {@link NsdServiceInfo} doesn't update any data other than subtypes. */ + public boolean isSubtypeOnlyUpdate(@NonNull NsdServiceInfo newInfo) { + return Objects.equals(newInfo.getServiceName(), mOriginalServiceName) + && Objects.equals(newInfo.getServiceType(), mServiceInfo.getServiceType()) + && newInfo.getPort() == mServiceInfo.getPort() + && Objects.equals(newInfo.getHostname(), mOriginalHostname) + && Objects.equals(newInfo.getHostAddresses(), mServiceInfo.getHostAddresses()) + && Objects.equals(newInfo.getNetwork(), mServiceInfo.getNetwork()); } /** @@ -521,8 +615,19 @@ public class MdnsAdvertiser { * @param newInfo New service info to use. * @param renameCount How many renames were done before reaching the current name. */ - private void updateForConflict(@NonNull NsdServiceInfo newInfo, int renameCount) { - mConflictCount += renameCount; + private void updateForServiceConflict(@NonNull NsdServiceInfo newInfo, int renameCount) { + mServiceNameConflictCount += renameCount; + mServiceInfo = newInfo; + } + + /** + * Update the registration to use a different host name, after a conflict was found. + * + * @param newInfo New service info to use. + * @param renameCount How many renames were done before reaching the current name. + */ + private void updateForHostConflict(@NonNull NsdServiceInfo newInfo, int renameCount) { + mHostnameConflictCount += renameCount; mServiceInfo = newInfo; } @@ -538,7 +643,7 @@ public class MdnsAdvertiser { * @param renameCount How much to increase the number suffix for this conflict. */ @NonNull - public NsdServiceInfo makeNewServiceInfoForConflict(int renameCount) { + public NsdServiceInfo makeNewServiceInfoForServiceConflict(int renameCount) { // In case of conflict choose a different service name. After the first conflict use // "Name (2)", then "Name (3)" etc. // TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t @@ -547,13 +652,40 @@ public class MdnsAdvertiser { return newInfo; } + /** + * Make a new hostname for the registration, after a conflict was found. + * + * <p>If a name conflict was found during probing or because different advertising requests + * used the same name, the registration is attempted again with a new name (here using a + * number suffix, -1, -2, etc). Registration success is notified once probing succeeds with + * a new name. + * + * @param renameCount How much to increase the number suffix for this conflict. + */ + @NonNull + public NsdServiceInfo makeNewServiceInfoForHostConflict(int renameCount) { + // In case of conflict choose a different hostname. After the first conflict use + // "Name-2", then "Name-3" etc. + final NsdServiceInfo newInfo = new NsdServiceInfo(mServiceInfo); + newInfo.setHostname(getUpdatedHostname(renameCount)); + return newInfo; + } + private String getUpdatedServiceName(int renameCount) { - final String suffix = " (" + (mConflictCount + renameCount + 1) + ")"; - final String truncatedServiceName = MdnsUtils.truncateServiceName(mOriginalName, + final String suffix = " (" + (mServiceNameConflictCount + renameCount + 1) + ")"; + final String truncatedServiceName = MdnsUtils.truncateServiceName(mOriginalServiceName, MAX_LABEL_LENGTH - suffix.length()); return truncatedServiceName + suffix; } + private String getUpdatedHostname(int renameCount) { + final String suffix = "-" + (mHostnameConflictCount + renameCount + 1); + final String truncatedHostname = + MdnsUtils.truncateServiceName( + mOriginalHostname, MAX_LABEL_LENGTH - suffix.length()); + return truncatedHostname + suffix; + } + @NonNull public NsdServiceInfo getServiceInfo() { return mServiceInfo; @@ -681,9 +813,10 @@ public class MdnsAdvertiser { * @param id A unique ID for the service. * @param service The service info to advertise. * @param advertisingOptions The advertising options. + * @param clientUid The UID who wants to advertise the service. */ public void addOrUpdateService(int id, NsdServiceInfo service, - MdnsAdvertisingOptions advertisingOptions) { + MdnsAdvertisingOptions advertisingOptions, int clientUid) { checkThread(); final Registration existingRegistration = mRegistrations.get(id); final Network network = service.getNetwork(); @@ -695,7 +828,7 @@ public class MdnsAdvertiser { mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR); return; } - if (!(existingRegistration.matches(service))) { + if (!(existingRegistration.isSubtypeOnlyUpdate(service))) { mSharedLog.e("Update request can only update subType, serviceInfo: " + service + ", existing serviceInfo: " + existingRegistration.getServiceInfo()); mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR); @@ -715,7 +848,7 @@ public class MdnsAdvertiser { } mSharedLog.i("Adding service " + service + " with ID " + id + " and subtypes " + subtypes + " advertisingOptions " + advertisingOptions); - registration = new Registration(service); + registration = new Registration(service, clientUid); final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter; if (network == null) { // If registering on all networks, no advertiser must have conflicts diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java index 730bd7ef5b8010c491142c4463dcc1a48295cf0a..f1deab0e678710f1d476209b7251fe7a7119d6b3 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java @@ -22,10 +22,12 @@ import android.annotation.NonNull; import android.annotation.Nullable; import android.annotation.RequiresApi; import android.net.LinkAddress; +import android.net.nsd.NsdManager; import android.net.nsd.NsdServiceInfo; import android.os.Build; import android.os.Handler; import android.os.Looper; +import android.util.ArraySet; import com.android.internal.annotations.VisibleForTesting; import com.android.net.module.util.HexDump; @@ -37,6 +39,7 @@ import com.android.server.connectivity.mdns.util.MdnsUtils; import java.io.IOException; import java.net.InetSocketAddress; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -44,6 +47,9 @@ import java.util.Set; */ @RequiresApi(Build.VERSION_CODES.TIRAMISU) public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHandler { + public static final int CONFLICT_SERVICE = 1 << 0; + public static final int CONFLICT_HOST = 1 << 1; + private static final boolean DBG = MdnsAdvertiser.DBG; @VisibleForTesting public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L; @@ -85,10 +91,15 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand /** * Called by the advertiser when a conflict was found, during or after probing. * - * If a conflict is found during probing, the {@link #renameServiceForConflict} must be + * <p>If a conflict is found during probing, the {@link #renameServiceForConflict} must be * called to restart probing and attempt registration with a different name. + * + * <p>{@code conflictType} is a bitmap telling which part of the service is conflicting. See + * {@link MdnsInterfaceAdvertiser#CONFLICT_SERVICE} and {@link + * MdnsInterfaceAdvertiser#CONFLICT_HOST}. */ - void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId); + void onServiceConflict( + @NonNull MdnsInterfaceAdvertiser advertiser, int serviceId, int conflictType); /** * Called by the advertiser when it destroyed itself. @@ -384,8 +395,16 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand + packet.additionalRecords.size() + " additional from " + srcCopy); } - for (int conflictServiceId : mRecordRepository.getConflictingServices(packet)) { - mCbHandler.post(() -> mCb.onServiceConflict(this, conflictServiceId)); + Map<Integer, Integer> conflictingServices = + mRecordRepository.getConflictingServices(packet); + + for (Map.Entry<Integer, Integer> entry : conflictingServices.entrySet()) { + int serviceId = entry.getKey(); + int conflictType = entry.getValue(); + mCbHandler.post( + () -> { + mCb.onServiceConflict(this, serviceId, conflictType); + }); } // Even in case of conflict, add replies for other services. But in general conflicts would diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java index 78c3082260f5fe5cca9b3499f727ba3c457f46e6..fb454547b7ccb35a382b260699f5c7b78ea0a563 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java @@ -19,6 +19,8 @@ package com.android.server.connectivity.mdns; import static com.android.server.connectivity.mdns.MdnsConstants.IPV4_SOCKET_ADDR; import static com.android.server.connectivity.mdns.MdnsConstants.IPV6_SOCKET_ADDR; import static com.android.server.connectivity.mdns.MdnsConstants.NO_PACKET; +import static com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST; +import static com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE; import android.annotation.NonNull; import android.annotation.Nullable; @@ -28,6 +30,8 @@ import android.net.nsd.NsdServiceInfo; import android.os.Build; import android.os.Looper; import android.os.SystemClock; +import android.text.TextUtils; +import android.util.ArrayMap; import android.util.ArraySet; import android.util.SparseArray; @@ -54,6 +58,8 @@ import java.util.Random; import java.util.Set; import java.util.TreeMap; import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Consumer; /** * A repository of records advertised through {@link MdnsInterfaceAdvertiser}. @@ -158,11 +164,13 @@ public class MdnsRecordRepository { public final List<RecordInfo<?>> allRecords; @NonNull public final List<RecordInfo<MdnsPointerRecord>> ptrRecords; - @NonNull + @Nullable public final RecordInfo<MdnsServiceRecord> srvRecord; - @NonNull + @Nullable public final RecordInfo<MdnsTextRecord> txtRecord; @NonNull + public final List<RecordInfo<MdnsInetAddressRecord>> addressRecords; + @NonNull public final NsdServiceInfo serviceInfo; /** @@ -202,65 +210,96 @@ public class MdnsRecordRepository { int repliedServiceCount, int sentPacketCount, boolean exiting, boolean isProbing) { this.serviceInfo = serviceInfo; - final String[] serviceType = splitServiceType(serviceInfo); - final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType); + final boolean hasService = !TextUtils.isEmpty(serviceInfo.getServiceType()); + final boolean hasCustomHost = !TextUtils.isEmpty(serviceInfo.getHostname()); + final String[] hostname = + hasCustomHost + ? new String[] {serviceInfo.getHostname(), LOCAL_TLD} + : deviceHostname; + final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(5); - // Service PTR records - ptrRecords = new ArrayList<>(serviceInfo.getSubtypes().size() + 1); - ptrRecords.add(new RecordInfo<>( - serviceInfo, - new MdnsPointerRecord( - serviceType, - 0L /* receiptTimeMillis */, - false /* cacheFlush */, - NON_NAME_RECORDS_TTL_MILLIS, - serviceName), - true /* sharedName */)); - for (String subtype : serviceInfo.getSubtypes()) { + if (hasService) { + final String[] serviceType = splitServiceType(serviceInfo); + final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType); + // Service PTR records + ptrRecords = new ArrayList<>(serviceInfo.getSubtypes().size() + 1); ptrRecords.add(new RecordInfo<>( - serviceInfo, - new MdnsPointerRecord( - MdnsUtils.constructFullSubtype(serviceType, subtype), - 0L /* receiptTimeMillis */, - false /* cacheFlush */, - NON_NAME_RECORDS_TTL_MILLIS, - serviceName), - true /* sharedName */)); - } + serviceInfo, + new MdnsPointerRecord( + serviceType, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + NON_NAME_RECORDS_TTL_MILLIS, + serviceName), + true /* sharedName */)); + for (String subtype : serviceInfo.getSubtypes()) { + ptrRecords.add(new RecordInfo<>( + serviceInfo, + new MdnsPointerRecord( + MdnsUtils.constructFullSubtype(serviceType, subtype), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + NON_NAME_RECORDS_TTL_MILLIS, + serviceName), + true /* sharedName */)); + } - srvRecord = new RecordInfo<>( - serviceInfo, - new MdnsServiceRecord(serviceName, - 0L /* receiptTimeMillis */, - true /* cacheFlush */, - NAME_RECORDS_TTL_MILLIS, 0 /* servicePriority */, 0 /* serviceWeight */, - serviceInfo.getPort(), - deviceHostname), - false /* sharedName */); - - txtRecord = new RecordInfo<>( - serviceInfo, - new MdnsTextRecord(serviceName, - 0L /* receiptTimeMillis */, - true /* cacheFlush */, // Service name is verified unique after probing - NON_NAME_RECORDS_TTL_MILLIS, - attrsToTextEntries(serviceInfo.getAttributes())), - false /* sharedName */); + srvRecord = new RecordInfo<>( + serviceInfo, + new MdnsServiceRecord(serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + NAME_RECORDS_TTL_MILLIS, + 0 /* servicePriority */, 0 /* serviceWeight */, + serviceInfo.getPort(), + hostname), + false /* sharedName */); + + txtRecord = new RecordInfo<>( + serviceInfo, + new MdnsTextRecord(serviceName, + 0L /* receiptTimeMillis */, + // Service name is verified unique after probing + true /* cacheFlush */, + NON_NAME_RECORDS_TTL_MILLIS, + attrsToTextEntries(serviceInfo.getAttributes())), + false /* sharedName */); + + allRecords.addAll(ptrRecords); + allRecords.add(srvRecord); + allRecords.add(txtRecord); + // Service type enumeration record (RFC6763 9.) + allRecords.add(new RecordInfo<>( + serviceInfo, + new MdnsPointerRecord( + DNS_SD_SERVICE_TYPE, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + NON_NAME_RECORDS_TTL_MILLIS, + serviceType), + true /* sharedName */)); + } else { + ptrRecords = Collections.emptyList(); + srvRecord = null; + txtRecord = null; + } - final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(5); - allRecords.addAll(ptrRecords); - allRecords.add(srvRecord); - allRecords.add(txtRecord); - // Service type enumeration record (RFC6763 9.) - allRecords.add(new RecordInfo<>( - serviceInfo, - new MdnsPointerRecord( - DNS_SD_SERVICE_TYPE, - 0L /* receiptTimeMillis */, - false /* cacheFlush */, - NON_NAME_RECORDS_TTL_MILLIS, - serviceType), - true /* sharedName */)); + if (hasCustomHost) { + addressRecords = new ArrayList<>(serviceInfo.getHostAddresses().size()); + for (InetAddress address : serviceInfo.getHostAddresses()) { + addressRecords.add(new RecordInfo<>( + serviceInfo, + new MdnsInetAddressRecord(hostname, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + NAME_RECORDS_TTL_MILLIS, + address), + false /* sharedName */)); + } + allRecords.addAll(addressRecords); + } else { + addressRecords = Collections.emptyList(); + } this.allRecords = Collections.unmodifiableList(allRecords); this.repliedServiceCount = repliedServiceCount; @@ -368,32 +407,38 @@ public class MdnsRecordRepository { /** * @return The ID of the service identified by its name, or -1 if none. */ - private int getServiceByName(@NonNull String serviceName) { + private int getServiceByName(@Nullable String serviceName) { + if (TextUtils.isEmpty(serviceName)) { + return -1; + } for (int i = 0; i < mServices.size(); i++) { final ServiceRegistration registration = mServices.valueAt(i); - if (MdnsUtils.equalsIgnoreDnsCase(serviceName, - registration.serviceInfo.getServiceName())) { + if (MdnsUtils.equalsIgnoreDnsCase( + serviceName, registration.serviceInfo.getServiceName())) { return mServices.keyAt(i); } } return -1; } - private MdnsProber.ProbingInfo makeProbingInfo(int serviceId, - @NonNull MdnsServiceRecord srvRecord, - @NonNull List<MdnsInetAddressRecord> inetAddressRecords) { + private MdnsProber.ProbingInfo makeProbingInfo( + int serviceId, ServiceRegistration registration) { final List<MdnsRecord> probingRecords = new ArrayList<>(); // Probe with cacheFlush cleared; it is set when announcing, as it was verified unique: // RFC6762 10.2 - probingRecords.add(new MdnsServiceRecord(srvRecord.getName(), - 0L /* receiptTimeMillis */, - false /* cacheFlush */, - srvRecord.getTtl(), - srvRecord.getServicePriority(), srvRecord.getServiceWeight(), - srvRecord.getServicePort(), - srvRecord.getServiceHost())); - - for (MdnsInetAddressRecord inetAddressRecord : inetAddressRecords) { + if (registration.srvRecord != null) { + MdnsServiceRecord srvRecord = registration.srvRecord.record; + probingRecords.add(new MdnsServiceRecord(srvRecord.getName(), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + srvRecord.getTtl(), + srvRecord.getServicePriority(), srvRecord.getServiceWeight(), + srvRecord.getServicePort(), + srvRecord.getServiceHost())); + } + + for (MdnsInetAddressRecord inetAddressRecord : + makeProbingInetAddressRecords(registration.serviceInfo)) { probingRecords.add(new MdnsInetAddressRecord(inetAddressRecord.getName(), 0L /* receiptTimeMillis */, false /* cacheFlush */, @@ -510,6 +555,9 @@ public class MdnsRecordRepository { public MdnsReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) { final long now = SystemClock.elapsedRealtime(); + // TODO: b/322142420 - Set<RecordInfo<?>> may contain duplicate records wrapped in different + // RecordInfo<?>s when custom host is enabled. + // Use LinkedHashSet for preserving the insert order of the RRs, so that RRs of the same // service or host are grouped together (which is more developer-friendly). final Set<RecordInfo<?>> answerInfo = new LinkedHashSet<>(); @@ -520,8 +568,10 @@ public class MdnsRecordRepository { for (MdnsRecord question : packet.questions) { // Add answers from general records if (addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */, - null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicastEnabled, - now, answerInfo, additionalAnswerInfo, Collections.emptyList())) { + null /* serviceSrvRecord */, null /* serviceTxtRecord */, + null /* hostname */, + replyUnicastEnabled, now, answerInfo, additionalAnswerInfo, + Collections.emptyList())) { replyUnicast &= question.isUnicastReplyRequested(); } @@ -530,7 +580,9 @@ public class MdnsRecordRepository { final ServiceRegistration registration = mServices.valueAt(i); if (registration.exiting || registration.isProbing) continue; if (addReplyFromService(question, registration.allRecords, registration.ptrRecords, - registration.srvRecord, registration.txtRecord, replyUnicastEnabled, now, + registration.srvRecord, registration.txtRecord, + registration.serviceInfo.getHostname(), + replyUnicastEnabled, now, answerInfo, additionalAnswerInfo, packet.answers)) { replyUnicast &= question.isUnicastReplyRequested(); registration.repliedServiceCount++; @@ -548,7 +600,12 @@ public class MdnsRecordRepository { final List<MdnsRecord> additionalAnswerRecords = new ArrayList<>(additionalAnswerInfo.size()); for (RecordInfo<?> info : additionalAnswerInfo) { - additionalAnswerRecords.add(info.record); + // Different RecordInfos may contain the same record. + // For example, when there are multiple services referring to the same custom host, + // there are multiple RecordInfos containing the same address record. + if (!additionalAnswerRecords.contains(info.record)) { + additionalAnswerRecords.add(info.record); + } } // RFC6762 6.1: negative responses @@ -618,7 +675,10 @@ public class MdnsRecordRepository { if (!replyUnicast) { info.lastAdvertisedTimeMs = info.lastSentTimeMs; } - answerRecords.add(info.record); + // Different RecordInfos may the contain the same record + if (!answerRecords.contains(info.record)) { + answerRecords.add(info.record); + } } return new MdnsReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest, src, @@ -642,6 +702,7 @@ public class MdnsRecordRepository { @Nullable List<RecordInfo<MdnsPointerRecord>> servicePtrRecords, @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord, @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord, + @Nullable String hostname, boolean replyUnicastEnabled, long now, @NonNull Set<RecordInfo<?>> answerInfo, @NonNull Set<RecordInfo<?>> additionalAnswerInfo, @NonNull List<MdnsRecord> knownAnswerRecords) { @@ -735,11 +796,7 @@ public class MdnsRecordRepository { // RFC6763 12.1&.2: if including PTR or SRV record, include the address records it names if (hasDnsSdPtrRecordAnswer || hasDnsSdSrvRecordAnswer) { - for (RecordInfo<?> record : mGeneralRecords) { - if (record.record instanceof MdnsInetAddressRecord) { - additionalAnswerInfo.add(record); - } - } + additionalAnswerInfo.addAll(getInetAddressRecordsForHostname(hostname)); } return true; } @@ -853,29 +910,46 @@ public class MdnsRecordRepository { MdnsProber.ProbingInfo probeSuccessInfo) throws IOException { - final ServiceRegistration registration = mServices.get(probeSuccessInfo.getServiceId()); - if (registration == null) throw new IOException( - "Service is not registered: " + probeSuccessInfo.getServiceId()); + int serviceId = probeSuccessInfo.getServiceId(); + final ServiceRegistration registration = mServices.get(serviceId); + if (registration == null) { + throw new IOException("Service is not registered: " + serviceId); + } registration.setProbing(false); - final ArrayList<MdnsRecord> answers = new ArrayList<>(); + final Set<MdnsRecord> answersSet = new LinkedHashSet<>(); final ArrayList<MdnsRecord> additionalAnswers = new ArrayList<>(); - // Interface address records in general records - for (RecordInfo<?> record : mGeneralRecords) { - answers.add(record.record); + // When using default host, add interface address records from general records + if (TextUtils.isEmpty(registration.serviceInfo.getHostname())) { + for (RecordInfo<?> record : mGeneralRecords) { + answersSet.add(record.record); + } + } else { + // TODO: b/321617573 - include PTR records for addresses + // The custom host may have more addresses in other registrations + forEachActiveServiceRegistrationWithHostname( + registration.serviceInfo.getHostname(), + (id, otherRegistration) -> { + if (otherRegistration.isProbing) { + return; + } + for (RecordInfo<?> addressRecordInfo : otherRegistration.addressRecords) { + answersSet.add(addressRecordInfo.record); + } + }); } // All service records for (RecordInfo<?> info : registration.allRecords) { - answers.add(info.record); + answersSet.add(info.record); } addNsecRecordsForUniqueNames(additionalAnswers, mGeneralRecords.iterator(), registration.allRecords.iterator()); - return new MdnsAnnouncer.AnnouncementInfo(probeSuccessInfo.getServiceId(), - answers, additionalAnswers); + return new MdnsAnnouncer.AnnouncementInfo( + probeSuccessInfo.getServiceId(), new ArrayList<>(answersSet), additionalAnswers); } /** @@ -894,8 +968,13 @@ public class MdnsRecordRepository { for (RecordInfo<MdnsPointerRecord> ptrRecord : registration.ptrRecords) { answers.add(ptrRecord.record); } - answers.add(registration.srvRecord.record); - answers.add(registration.txtRecord.record); + if (registration.srvRecord != null) { + answers.add(registration.srvRecord.record); + } + if (registration.txtRecord != null) { + answers.add(registration.txtRecord.record); + } + // TODO: Support custom host. It currently only supports default host. for (RecordInfo<?> record : mGeneralRecords) { if (record.record instanceof MdnsInetAddressRecord) { answers.add(record.record); @@ -910,70 +989,181 @@ public class MdnsRecordRepository { Collections.emptyList() /* additionalRecords */); } + /** Check if the record is in any service registration */ + private boolean hasInetAddressRecord(@NonNull MdnsInetAddressRecord record) { + for (int i = 0; i < mServices.size(); i++) { + final ServiceRegistration registration = mServices.valueAt(i); + if (registration.exiting) continue; + + for (RecordInfo<MdnsInetAddressRecord> localRecord : registration.addressRecords) { + if (Objects.equals(localRecord.record, record)) { + return true; + } + } + } + return false; + } + /** * Get the service IDs of services conflicting with a received packet. + * + * <p>It returns a Map of service ID => conflict type. Conflict type is a bitmap telling which + * part of the service is conflicting. See {@link MdnsInterfaceAdvertiser#CONFLICT_SERVICE} and + * {@link MdnsInterfaceAdvertiser#CONFLICT_HOST}. */ - public Set<Integer> getConflictingServices(MdnsPacket packet) { + public Map<Integer, Integer> getConflictingServices(MdnsPacket packet) { // Avoid allocating a new set for each incoming packet: use an empty set by default. - Set<Integer> conflicting = Collections.emptySet(); + Map<Integer, Integer> conflicting = Collections.emptyMap(); for (MdnsRecord record : packet.answers) { for (int i = 0; i < mServices.size(); i++) { final ServiceRegistration registration = mServices.valueAt(i); if (registration.exiting) continue; - // Only look for conflicts in service name, as a different service name can be used - // if there is a conflict, but there is nothing actionable if any other conflict - // happens. In fact probing is only done for the service name in the SRV record. - // This means only SRV and TXT records need to be checked. - final RecordInfo<MdnsServiceRecord> srvRecord = registration.srvRecord; - if (!MdnsUtils.equalsDnsLabelIgnoreDnsCase(record.getName(), - srvRecord.record.getName())) { - continue; - } + int conflictType = 0; - // As per RFC6762 9., it's fine if the "conflict" is an identical record with same - // data. - if (record instanceof MdnsServiceRecord) { - final MdnsServiceRecord local = srvRecord.record; - final MdnsServiceRecord other = (MdnsServiceRecord) record; - // Note "equals" does not consider TTL or receipt time, as intended here - if (Objects.equals(local, other)) { - continue; - } + if (conflictForService(record, registration)) { + conflictType |= CONFLICT_SERVICE; } - if (record instanceof MdnsTextRecord) { - final MdnsTextRecord local = registration.txtRecord.record; - final MdnsTextRecord other = (MdnsTextRecord) record; - if (Objects.equals(local, other)) { - continue; - } + if (conflictForHost(record, registration)) { + conflictType |= CONFLICT_HOST; } - if (conflicting.size() == 0) { - // Conflict was found: use a mutable set - conflicting = new ArraySet<>(); + if (conflictType != 0) { + if (conflicting.isEmpty()) { + // Conflict was found: use a mutable set + conflicting = new ArrayMap<>(); + } + final int serviceId = mServices.keyAt(i); + conflicting.put(serviceId, conflictType); } - final int serviceId = mServices.keyAt(i); - conflicting.add(serviceId); } } return conflicting; } - private List<MdnsInetAddressRecord> makeProbingInetAddressRecords() { + + private static boolean conflictForService( + @NonNull MdnsRecord record, @NonNull ServiceRegistration registration) { + if (registration.srvRecord == null) { + return false; + } + + final RecordInfo<MdnsServiceRecord> srvRecord = registration.srvRecord; + if (!MdnsUtils.equalsDnsLabelIgnoreDnsCase(record.getName(), srvRecord.record.getName())) { + return false; + } + + // As per RFC6762 9., it's fine if the "conflict" is an identical record with same + // data. + if (record instanceof MdnsServiceRecord) { + final MdnsServiceRecord local = srvRecord.record; + final MdnsServiceRecord other = (MdnsServiceRecord) record; + // Note "equals" does not consider TTL or receipt time, as intended here + if (Objects.equals(local, other)) { + return false; + } + } + + if (record instanceof MdnsTextRecord) { + final MdnsTextRecord local = registration.txtRecord.record; + final MdnsTextRecord other = (MdnsTextRecord) record; + if (Objects.equals(local, other)) { + return false; + } + } + return true; + } + + private boolean conflictForHost( + @NonNull MdnsRecord record, @NonNull ServiceRegistration registration) { + // Only custom hosts are checked. When using the default host, the hostname is derived from + // a UUID and it's supposed to be unique. + if (registration.serviceInfo.getHostname() == null) { + return false; + } + + // The record's name cannot be registered by NsdManager so it's not a conflict. + if (record.getName().length != 2 || !record.getName()[1].equals(LOCAL_TLD)) { + return false; + } + + // Different names. There won't be a conflict. + if (!MdnsUtils.equalsIgnoreDnsCase( + record.getName()[0], registration.serviceInfo.getHostname())) { + return false; + } + + // If this registration has any address record and there's no identical record in the + // repository, it's a conflict. There will be no conflict if no registration has addresses + // for that hostname. + if (record instanceof MdnsInetAddressRecord) { + if (!registration.addressRecords.isEmpty()) { + return !hasInetAddressRecord((MdnsInetAddressRecord) record); + } + } + + return false; + } + + private List<RecordInfo<MdnsInetAddressRecord>> getInetAddressRecordsForHostname( + @Nullable String hostname) { + List<RecordInfo<MdnsInetAddressRecord>> records = new ArrayList<>(); + if (TextUtils.isEmpty(hostname)) { + forEachAddressRecord(mGeneralRecords, records::add); + } else { + forEachActiveServiceRegistrationWithHostname( + hostname, + (id, service) -> { + if (service.isProbing) return; + records.addAll(service.addressRecords); + }); + } + return records; + } + + private List<MdnsInetAddressRecord> makeProbingInetAddressRecords( + @NonNull NsdServiceInfo serviceInfo) { final List<MdnsInetAddressRecord> records = new ArrayList<>(); - if (mMdnsFeatureFlags.mIncludeInetAddressRecordsInProbing) { - for (RecordInfo<?> record : mGeneralRecords) { - if (record.record instanceof MdnsInetAddressRecord) { - records.add((MdnsInetAddressRecord) record.record); - } + if (TextUtils.isEmpty(serviceInfo.getHostname())) { + if (mMdnsFeatureFlags.mIncludeInetAddressRecordsInProbing) { + forEachAddressRecord(mGeneralRecords, r -> records.add(r.record)); } + } else { + forEachActiveServiceRegistrationWithHostname( + serviceInfo.getHostname(), + (id, service) -> { + for (RecordInfo<MdnsInetAddressRecord> recordInfo : + service.addressRecords) { + records.add(recordInfo.record); + } + }); } return records; } + private static void forEachAddressRecord( + List<RecordInfo<?>> records, Consumer<RecordInfo<MdnsInetAddressRecord>> consumer) { + for (RecordInfo<?> record : records) { + if (record.record instanceof MdnsInetAddressRecord) { + consumer.accept((RecordInfo<MdnsInetAddressRecord>) record); + } + } + } + + private void forEachActiveServiceRegistrationWithHostname( + @NonNull String hostname, BiConsumer<Integer, ServiceRegistration> consumer) { + for (int i = 0; i < mServices.size(); ++i) { + int id = mServices.keyAt(i); + ServiceRegistration service = mServices.valueAt(i); + if (service.exiting) continue; + if (MdnsUtils.equalsIgnoreDnsCase(service.serviceInfo.getHostname(), hostname)) { + consumer.accept(id, service); + } + } + } + /** * (Re)set a service to the probing state. * @return The {@link MdnsProber.ProbingInfo} to send for probing. @@ -984,8 +1174,8 @@ public class MdnsRecordRepository { if (registration == null) return null; registration.setProbing(true); - return makeProbingInfo( - serviceId, registration.srvRecord.record, makeProbingInetAddressRecords()); + + return makeProbingInfo(serviceId, registration); } /** @@ -1021,8 +1211,7 @@ public class MdnsRecordRepository { final ServiceRegistration newService = new ServiceRegistration(mDeviceHostname, newInfo, existing.repliedServiceCount, existing.sentPacketCount); mServices.put(serviceId, newService); - return makeProbingInfo( - serviceId, newService.srvRecord.record, makeProbingInetAddressRecords()); + return makeProbingInfo(serviceId, newService); } /** diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java index 8fc81148a450d565aab32703845b5cf9d5e251dd..d553210fa1f10f3d5b0d719206909fff95caa05a 100644 --- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java +++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java @@ -86,7 +86,10 @@ public class MdnsUtils { /** * Compare two strings by DNS case-insensitive lowercase. */ - public static boolean equalsIgnoreDnsCase(@NonNull String a, @NonNull String b) { + public static boolean equalsIgnoreDnsCase(@Nullable String a, @Nullable String b) { + if (a == null || b == null) { + return a == null && b == null; + } if (a.length() != b.length()) return false; for (int i = 0; i < a.length(); i++) { if (toDnsLowerCase(a.charAt(i)) != toDnsLowerCase(b.charAt(i))) { diff --git a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java index 79c4980a5c3d32b97ca1c6fc044907a165868b6c..8e89037724221fa950ca7883b8b5e222deccb11e 100644 --- a/tests/common/java/android/net/nsd/NsdServiceInfoTest.java +++ b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java @@ -119,6 +119,7 @@ public class NsdServiceInfoTest { fullInfo.setSubtypes(Set.of("_thread", "_matter")); fullInfo.setPort(4242); fullInfo.setHostAddresses(List.of(IPV4_ADDRESS)); + fullInfo.setHostname("home"); fullInfo.setNetwork(new Network(123)); fullInfo.setInterfaceIndex(456); checkParcelable(fullInfo); @@ -134,6 +135,7 @@ public class NsdServiceInfoTest { attributedInfo.setServiceType("_kitten._tcp"); attributedInfo.setPort(4242); attributedInfo.setHostAddresses(List.of(IPV6_ADDRESS, IPV4_ADDRESS)); + attributedInfo.setHostname("home"); attributedInfo.setAttribute("color", "pink"); attributedInfo.setAttribute("sound", (new String("ã«ã‚ƒã‚")).getBytes("UTF-8")); attributedInfo.setAttribute("adorable", (String) null); @@ -169,6 +171,7 @@ public class NsdServiceInfoTest { assertEquals(original.getServiceName(), result.getServiceName()); assertEquals(original.getServiceType(), result.getServiceType()); assertEquals(original.getHost(), result.getHost()); + assertEquals(original.getHostname(), result.getHostname()); assertTrue(original.getPort() == result.getPort()); assertEquals(original.getNetwork(), result.getNetwork()); assertEquals(original.getInterfaceIndex(), result.getInterfaceIndex()); diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt index 8f9f8c7043bdf7bb9dd594b7e6c9ecbefdeb5e76..c368d5bc3e74b6d04b7dcbe21540b2ef6f06c6d6 100644 --- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt +++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt @@ -127,6 +127,7 @@ import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith +import kotlin.test.assertNotEquals private const val TAG = "NsdManagerTest" private const val TIMEOUT_MS = 2000L @@ -162,7 +163,11 @@ class NsdManagerTest { private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! } private val serviceName = "NsdTest%09d".format(Random().nextInt(1_000_000_000)) private val serviceName2 = "NsdTest%09d".format(Random().nextInt(1_000_000_000)) + private val serviceName3 = "NsdTest%09d".format(Random().nextInt(1_000_000_000)) private val serviceType = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000)) + private val serviceType2 = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000)) + private val customHostname = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000)) + private val customHostname2 = "NsdTestHost%09d".format(Random().nextInt(1_000_000_000)) private val handlerThread = HandlerThread(NsdManagerTest::class.java.simpleName) private val ctsNetUtils by lazy{ CtsNetUtils(context) } @@ -1188,6 +1193,84 @@ class NsdManagerTest { } } + @Test + fun testRegisterServiceWithCustomHostAndAddresses_conflictDuringProbing_hostRenamed() { + val si = makeTestServiceInfo(testNetwork1.network).apply { + hostname = customHostname + hostAddresses = listOf( + parseNumericAddress("192.0.2.24"), + parseNumericAddress("2001:db8::3")) + } + + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + // Register service on testNetwork1 + val registrationRecord = NsdRegistrationRecord() + nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() }, + registrationRecord) + + tryTest { + assertNotNull(packetReader.pollForProbe(serviceName, serviceType), + "Did not find a probe for the service") + packetReader.sendResponse(buildConflictingAnnouncementForCustomHost()) + + // Registration must use an updated hostname to avoid the conflict + val cb = registrationRecord.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS) + // Service name is not renamed because there's no conflict on the service name. + // TODO: b/283053491 - enable this check +// assertEquals(serviceName, cb.serviceInfo.serviceName) + val hostname = cb.serviceInfo.hostname ?: fail("Missing hostname") + hostname.let { + assertTrue("Unexpected registered hostname: $it", + it.startsWith(customHostname) && it != customHostname) + } + } cleanupStep { + nsdManager.unregisterService(registrationRecord) + registrationRecord.expectCallback<ServiceUnregistered>() + } cleanup { + packetReader.handler.post { packetReader.stop() } + handlerThread.waitForIdle(TIMEOUT_MS) + } + } + + @Test + fun testRegisterServiceWithCustomHostNoAddresses_noConflictDuringProbing_notRenamed() { + val si = makeTestServiceInfo(testNetwork1.network).apply { + hostname = customHostname + } + + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + // Register service on testNetwork1 + val registrationRecord = NsdRegistrationRecord() + nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() }, + registrationRecord) + + tryTest { + assertNotNull(packetReader.pollForProbe(serviceName, serviceType), + "Did not find a probe for the service") + // Not a conflict because no record is registered for the hostname + packetReader.sendResponse(buildConflictingAnnouncementForCustomHost()) + + // Registration is not renamed because there's no conflict + val cb = registrationRecord.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS) + assertEquals(serviceName, cb.serviceInfo.serviceName) + assertEquals(customHostname, cb.serviceInfo.hostname) + } cleanupStep { + nsdManager.unregisterService(registrationRecord) + registrationRecord.expectCallback<ServiceUnregistered>() + } cleanup { + packetReader.handler.post { packetReader.stop() } + handlerThread.waitForIdle(TIMEOUT_MS) + } + } + @Test fun testRegisterWithConflictAfterProbing() { // This test requires shims supporting T+ APIs (NsdServiceInfo.network) @@ -1263,6 +1346,52 @@ class NsdManagerTest { } } + // TODO: b/322282952 - Add the test case that the hostname is renamed due to a conflict after + // probing succeeded. + + @Test + fun testRegisterServiceWithCustomHostNoAddresses_noConflictAfterProbing_notRenamed() { + val si = makeTestServiceInfo(testNetwork1.network).apply { + hostname = customHostname + } + + // Register service on testNetwork1 + val registrationRecord = NsdRegistrationRecord() + val discoveryRecord = NsdDiscoveryRecord() + val registeredService = registerService(registrationRecord, si) + val packetReader = TapPacketReader(Handler(handlerThread.looper), + testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */) + packetReader.startAsyncForTest() + handlerThread.waitForIdle(TIMEOUT_MS) + + tryTest { + assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType), + "No announcements sent after initial probing") + + assertEquals(si.serviceName, registeredService.serviceName) + assertEquals(si.hostname, registeredService.hostname) + + // Send a conflicting announcement + val conflictingAnnouncement = buildConflictingAnnouncementForCustomHost() + packetReader.sendResponse(conflictingAnnouncement) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, { it.run() }, discoveryRecord) + + // The service is not renamed + discoveryRecord.waitForServiceDiscovered(si.serviceName, serviceType) + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord) + discoveryRecord.expectCallback<DiscoveryStopped>() + } cleanupStep { + nsdManager.unregisterService(registrationRecord) + registrationRecord.expectCallback<ServiceUnregistered>() + } cleanup { + packetReader.handler.post { packetReader.stop() } + handlerThread.waitForIdle(TIMEOUT_MS) + } + } + // Test that even if only a PTR record is received as a reply when discovering, without the // SRV, TXT, address records as recommended (but not mandated) by RFC 6763 12, the service can // still be discovered. @@ -1447,6 +1576,212 @@ class NsdManagerTest { return Inet6Address.getByAddress(addrBytes) as Inet6Address } + @Test + fun testAdvertisingAndDiscovery_servicesWithCustomHost_customHostAddressesFound() { + val hostAddresses1 = listOf( + parseNumericAddress("192.0.2.23"), + parseNumericAddress("2001:db8::1"), + parseNumericAddress("2001:db8::2")) + val hostAddresses2 = listOf( + parseNumericAddress("192.0.2.24"), + parseNumericAddress("2001:db8::3")) + val si1 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.serviceName = serviceName + it.serviceType = serviceType + it.port = TEST_PORT + it.hostname = customHostname + it.hostAddresses = hostAddresses1 + } + val si2 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.serviceName = serviceName2 + it.serviceType = serviceType + it.port = TEST_PORT + 1 + it.hostname = customHostname2 + it.hostAddresses = hostAddresses2 + } + val registrationRecord1 = NsdRegistrationRecord() + val registrationRecord2 = NsdRegistrationRecord() + + val discoveryRecord1 = NsdDiscoveryRecord() + val discoveryRecord2 = NsdDiscoveryRecord() + tryTest { + registerService(registrationRecord1, si1) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, Executor { it.run() }, discoveryRecord1) + + val discoveredInfo = discoveryRecord1.waitForServiceDiscovered( + serviceName, serviceType, testNetwork1.network) + val resolvedInfo = resolveService(discoveredInfo) + + assertEquals(TEST_PORT, resolvedInfo.port) + assertEquals(si1.hostname, resolvedInfo.hostname) + assertAddressEquals(hostAddresses1, resolvedInfo.hostAddresses) + + registerService(registrationRecord2, si2) + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, Executor { it.run() }, discoveryRecord2) + + val discoveredInfo2 = discoveryRecord2.waitForServiceDiscovered( + serviceName2, serviceType, testNetwork1.network) + val resolvedInfo2 = resolveService(discoveredInfo2) + + assertEquals(TEST_PORT + 1, resolvedInfo2.port) + assertEquals(si2.hostname, resolvedInfo2.hostname) + assertAddressEquals(hostAddresses2, resolvedInfo2.hostAddresses) + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord1) + nsdManager.stopServiceDiscovery(discoveryRecord2) + + discoveryRecord1.expectCallbackEventually<DiscoveryStopped>() + discoveryRecord2.expectCallbackEventually<DiscoveryStopped>() + } cleanup { + nsdManager.unregisterService(registrationRecord1) + nsdManager.unregisterService(registrationRecord2) + } + } + + @Test + fun testAdvertisingAndDiscovery_multipleRegistrationsForSameCustomHost_unionOfAddressesFound() { + val hostAddresses1 = listOf( + parseNumericAddress("192.0.2.23"), + parseNumericAddress("2001:db8::1"), + parseNumericAddress("2001:db8::2")) + val hostAddresses2 = listOf( + parseNumericAddress("192.0.2.24"), + parseNumericAddress("2001:db8::3")) + val hostAddresses3 = listOf( + parseNumericAddress("2001:db8::3"), + parseNumericAddress("2001:db8::5")) + val si1 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.hostname = customHostname + it.hostAddresses = hostAddresses1 + } + val si2 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.serviceName = serviceName + it.serviceType = serviceType + it.port = TEST_PORT + it.hostname = customHostname + it.hostAddresses = hostAddresses2 + } + val si3 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.serviceName = serviceName3 + it.serviceType = serviceType + it.port = TEST_PORT + 1 + it.hostname = customHostname + it.hostAddresses = hostAddresses3 + } + + val registrationRecord1 = NsdRegistrationRecord() + val registrationRecord2 = NsdRegistrationRecord() + val registrationRecord3 = NsdRegistrationRecord() + + val discoveryRecord = NsdDiscoveryRecord() + tryTest { + registerService(registrationRecord1, si1) + registerService(registrationRecord2, si2) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, Executor { it.run() }, discoveryRecord) + + val discoveredInfo1 = discoveryRecord.waitForServiceDiscovered( + serviceName, serviceType, testNetwork1.network) + val resolvedInfo1 = resolveService(discoveredInfo1) + + assertEquals(TEST_PORT, resolvedInfo1.port) + assertEquals(si1.hostname, resolvedInfo1.hostname) + assertAddressEquals( + hostAddresses1 + hostAddresses2, + resolvedInfo1.hostAddresses) + + registerService(registrationRecord3, si3) + + val discoveredInfo2 = discoveryRecord.waitForServiceDiscovered( + serviceName3, serviceType, testNetwork1.network) + val resolvedInfo2 = resolveService(discoveredInfo2) + + assertEquals(TEST_PORT + 1, resolvedInfo2.port) + assertEquals(si2.hostname, resolvedInfo2.hostname) + assertAddressEquals( + hostAddresses1 + hostAddresses2 + hostAddresses3, + resolvedInfo2.hostAddresses) + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord) + + discoveryRecord.expectCallbackEventually<DiscoveryStopped>() + } cleanup { + nsdManager.unregisterService(registrationRecord1) + nsdManager.unregisterService(registrationRecord2) + nsdManager.unregisterService(registrationRecord3) + } + } + + @Test + fun testAdvertisingAndDiscovery_servicesWithTheSameCustomHostAddressOmitted_addressesFound() { + val hostAddresses = listOf( + parseNumericAddress("192.0.2.23"), + parseNumericAddress("2001:db8::1"), + parseNumericAddress("2001:db8::2")) + val si1 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.serviceType = serviceType + it.serviceName = serviceName + it.port = TEST_PORT + it.hostname = customHostname + it.hostAddresses = hostAddresses + } + val si2 = NsdServiceInfo().also { + it.network = testNetwork1.network + it.serviceType = serviceType + it.serviceName = serviceName2 + it.port = TEST_PORT + 1 + it.hostname = customHostname + } + + val registrationRecord1 = NsdRegistrationRecord() + val registrationRecord2 = NsdRegistrationRecord() + + val discoveryRecord = NsdDiscoveryRecord() + tryTest { + registerService(registrationRecord1, si1) + + nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, + testNetwork1.network, Executor { it.run() }, discoveryRecord) + + val discoveredInfo1 = discoveryRecord.waitForServiceDiscovered( + serviceName, serviceType, testNetwork1.network) + val resolvedInfo1 = resolveService(discoveredInfo1) + + assertEquals(serviceName, discoveredInfo1.serviceName) + assertEquals(TEST_PORT, resolvedInfo1.port) + assertEquals(si1.hostname, resolvedInfo1.hostname) + assertAddressEquals(hostAddresses, resolvedInfo1.hostAddresses) + + registerService(registrationRecord2, si2) + + val discoveredInfo2 = discoveryRecord.waitForServiceDiscovered( + serviceName2, serviceType, testNetwork1.network) + val resolvedInfo2 = resolveService(discoveredInfo2) + + assertEquals(serviceName2, discoveredInfo2.serviceName) + assertEquals(TEST_PORT + 1, resolvedInfo2.port) + assertEquals(si2.hostname, resolvedInfo2.hostname) + assertAddressEquals(hostAddresses, resolvedInfo2.hostAddresses) + } cleanupStep { + nsdManager.stopServiceDiscovery(discoveryRecord) + + discoveryRecord.expectCallback<DiscoveryStopped>() + } cleanup { + nsdManager.unregisterService(registrationRecord1) + nsdManager.unregisterService(registrationRecord2) + } + } + private fun buildConflictingAnnouncement(): ByteBuffer { /* Generated with: @@ -1463,6 +1798,22 @@ class NsdManagerTest { return buildMdnsPacket(mdnsPayload) } + private fun buildConflictingAnnouncementForCustomHost(): ByteBuffer { + /* + Generated with scapy: + raw(DNS(rd=0, qr=1, aa=1, qd = None, an = + DNSRR(rrname='NsdTestHost123456789.local', type=28, rclass=1, ttl=120, + rdata='2001:db8::321') + )).hex() + */ + val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000144e7364" + + "54657374486f7374313233343536373839056c6f63616c00001c000100000078001020010db80000" + + "00000000000000000321") + replaceCustomHostnameWithTestSuffix(mdnsPayload) + + return buildMdnsPacket(mdnsPayload) + } + /** * Replaces occurrences of "NsdTest123456789" and "_nmt123456789" in mDNS payload with the * actual random name and type that are used by the test. @@ -1479,6 +1830,19 @@ class NsdManagerTest { replaceAll(packetBuffer, testPacketTypePrefix, encodedTypePrefix) } + /** + * Replaces occurrences of "NsdTestHost123456789" in mDNS payload with the + * actual random host name that are used by the test. + */ + private fun replaceCustomHostnameWithTestSuffix(mdnsPayload: ByteArray) { + // Test custom hostnames have consistent length and are always ASCII + val testPacketName = "NsdTestHost123456789".encodeToByteArray() + val encodedHostname = customHostname.encodeToByteArray() + + val packetBuffer = ByteBuffer.wrap(mdnsPayload) + replaceAll(packetBuffer, testPacketName, encodedHostname) + } + private tailrec fun replaceAll(buffer: ByteBuffer, source: ByteArray, replacement: ByteArray) { assertEquals(source.size, replacement.size) val index = buffer.array().indexOf(source) @@ -1577,3 +1941,9 @@ private fun ByteArray?.utf8ToString(): String { if (this == null) return "" return String(this, StandardCharsets.UTF_8) } + +private fun assertAddressEquals(expected: List<InetAddress>, actual: List<InetAddress>) { + // No duplicate addresses in the actual address list + assertEquals(actual.toSet().size, actual.size) + assertEquals(expected.toSet(), actual.toSet()) +} \ No newline at end of file diff --git a/tests/unit/java/android/net/nsd/NsdManagerTest.java b/tests/unit/java/android/net/nsd/NsdManagerTest.java index aabe8d3423022b0ae47141c6e96222e19d87e546..951675ca797c979bba7058ba898f7825e5eaf914 100644 --- a/tests/unit/java/android/net/nsd/NsdManagerTest.java +++ b/tests/unit/java/android/net/nsd/NsdManagerTest.java @@ -52,6 +52,9 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import java.net.InetAddress; +import java.util.List; + @DevSdkIgnoreRunner.MonitorThreadLeak @RunWith(DevSdkIgnoreRunner.class) @SmallTest @@ -370,6 +373,9 @@ public class NsdManagerTest { NsdManager.RegistrationListener listener1 = mock(NsdManager.RegistrationListener.class); NsdManager.DiscoveryListener listener2 = mock(NsdManager.DiscoveryListener.class); NsdManager.ResolveListener listener3 = mock(NsdManager.ResolveListener.class); + NsdManager.RegistrationListener listener4 = mock(NsdManager.RegistrationListener.class); + NsdManager.RegistrationListener listener5 = mock(NsdManager.RegistrationListener.class); + NsdManager.RegistrationListener listener6 = mock(NsdManager.RegistrationListener.class); NsdServiceInfo invalidService = new NsdServiceInfo(null, null); NsdServiceInfo validService = new NsdServiceInfo("a_name", "_a_type._tcp"); @@ -379,6 +385,7 @@ public class NsdManagerTest { "_a_type._tcp,_sub1,_s2"); NsdServiceInfo otherSubtypeUpdate = new NsdServiceInfo("a_name", "_a_type._tcp,_sub1,_s3"); NsdServiceInfo dotSyntaxSubtypeUpdate = new NsdServiceInfo("a_name", "_sub1._a_type._tcp"); + validService.setPort(2222); otherServiceWithSubtype.setPort(2222); validServiceDuplicate.setPort(2222); @@ -386,6 +393,33 @@ public class NsdManagerTest { otherSubtypeUpdate.setPort(2222); dotSyntaxSubtypeUpdate.setPort(2222); + NsdServiceInfo invalidMissingHostnameWithAddresses = new NsdServiceInfo(null, null); + invalidMissingHostnameWithAddresses.setHostAddresses( + List.of( + InetAddress.parseNumericAddress("192.168.82.14"), + InetAddress.parseNumericAddress("2001::1"))); + + NsdServiceInfo validCustomHostWithAddresses = new NsdServiceInfo(null, null); + validCustomHostWithAddresses.setHostname("a_host"); + validCustomHostWithAddresses.setHostAddresses( + List.of( + InetAddress.parseNumericAddress("192.168.82.14"), + InetAddress.parseNumericAddress("2001::1"))); + + NsdServiceInfo validServiceWithCustomHostAndAddresses = + new NsdServiceInfo("a_name", "_a_type._tcp"); + validServiceWithCustomHostAndAddresses.setPort(2222); + validServiceWithCustomHostAndAddresses.setHostname("a_host"); + validServiceWithCustomHostAndAddresses.setHostAddresses( + List.of( + InetAddress.parseNumericAddress("192.168.82.14"), + InetAddress.parseNumericAddress("2001::1"))); + + NsdServiceInfo validServiceWithCustomHostNoAddresses = + new NsdServiceInfo("a_name", "_a_type._tcp"); + validServiceWithCustomHostNoAddresses.setPort(2222); + validServiceWithCustomHostNoAddresses.setHostname("a_host"); + // Service registration // - invalid arguments mustFail(() -> { manager.unregisterService(null); }); @@ -394,6 +428,8 @@ public class NsdManagerTest { mustFail(() -> { manager.registerService(invalidService, PROTOCOL, listener1); }); mustFail(() -> { manager.registerService(validService, -1, listener1); }); mustFail(() -> { manager.registerService(validService, PROTOCOL, null); }); + mustFail(() -> { + manager.registerService(invalidMissingHostnameWithAddresses, PROTOCOL, listener1); }); manager.registerService(validService, PROTOCOL, listener1); // - update without subtype is not allowed mustFail(() -> { manager.registerService(validServiceDuplicate, PROTOCOL, listener1); }); @@ -415,6 +451,15 @@ public class NsdManagerTest { // TODO: make listener immediately reusable //mustFail(() -> { manager.unregisterService(listener1); }); //manager.registerService(validService, PROTOCOL, listener1); + // - registering a custom host without a service is valid + manager.registerService(validCustomHostWithAddresses, PROTOCOL, listener4); + manager.unregisterService(listener4); + // - registering a service with a custom host is valid + manager.registerService(validServiceWithCustomHostAndAddresses, PROTOCOL, listener5); + manager.unregisterService(listener5); + // - registering a service with a custom host with no addresses is valid + manager.registerService(validServiceWithCustomHostNoAddresses, PROTOCOL, listener6); + manager.unregisterService(listener6); // Discover service // - invalid arguments diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java index a17197e4bb72c3c0e2026438eb2983ce5e62da4c..b60f0b4cff82843914b5e7c304157306d6d5f74e 100644 --- a/tests/unit/java/com/android/server/NsdServiceTest.java +++ b/tests/unit/java/com/android/server/NsdServiceTest.java @@ -1139,7 +1139,7 @@ public class NsdServiceTest { verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(s -> "Instance".equals(s.getServiceName()) && SERVICE_TYPE.equals(s.getServiceType()) - && s.getSubtypes().equals(Set.of("_subtype"))), any()); + && s.getSubtypes().equals(Set.of("_subtype"))), any(), anyInt()); final DiscoveryListener discListener = mock(DiscoveryListener.class); client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener); @@ -1246,7 +1246,7 @@ public class NsdServiceTest { final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class); verify(mAdvertiser).addOrUpdateService(serviceIdCaptor.capture(), - argThat(info -> matches(info, regInfo)), any()); + argThat(info -> matches(info, regInfo)), any(), anyInt()); client.unregisterService(regListenerWithoutFeature); waitForIdle(); @@ -1275,7 +1275,7 @@ public class NsdServiceTest { service1.setHostAddresses(List.of(parseNumericAddress("2001:db8::123"))); service1.setPort(1234); final NsdServiceInfo service2 = new NsdServiceInfo(SERVICE_NAME, "_type2._tcp"); - service2.setHostAddresses(List.of(parseNumericAddress("2001:db8::123"))); + service1.setHostAddresses(List.of(parseNumericAddress("2001:db8::123"))); service2.setPort(1234); client.discoverServices(service1.getServiceType(), @@ -1307,9 +1307,9 @@ public class NsdServiceTest { // The advertiser is enabled for _type2 but not _type1 verify(mAdvertiser, never()).addOrUpdateService(anyInt(), - argThat(info -> matches(info, service1)), any()); + argThat(info -> matches(info, service1)), any(), anyInt()); verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(info -> matches(info, service2)), - any()); + any(), anyInt()); } @Test @@ -1334,7 +1334,7 @@ public class NsdServiceTest { verify(mSocketProvider).startMonitoringSockets(); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); verify(mAdvertiser).addOrUpdateService(idCaptor.capture(), argThat(info -> - matches(info, regInfo)), any()); + matches(info, regInfo)), any(), anyInt()); // Verify onServiceRegistered callback final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); @@ -1382,7 +1382,7 @@ public class NsdServiceTest { client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); waitForIdle(); - verify(mAdvertiser, never()).addOrUpdateService(anyInt(), any(), any()); + verify(mAdvertiser, never()).addOrUpdateService(anyInt(), any(), any(), anyInt()); verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed( argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR)); @@ -1411,8 +1411,12 @@ public class NsdServiceTest { waitForIdle(); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); // Service name is truncated to 63 characters - verify(mAdvertiser).addOrUpdateService(idCaptor.capture(), - argThat(info -> info.getServiceName().equals("a".repeat(63))), any()); + verify(mAdvertiser) + .addOrUpdateService( + idCaptor.capture(), + argThat(info -> info.getServiceName().equals("a".repeat(63))), + any(), + anyInt()); // Verify onServiceRegistered callback final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); @@ -1510,7 +1514,7 @@ public class NsdServiceTest { client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); waitForIdle(); verify(mSocketProvider).startMonitoringSockets(); - verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any()); + verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), anyInt()); // Verify the discovery uses MdnsDiscoveryManager final DiscoveryListener discListener = mock(DiscoveryListener.class); @@ -1543,7 +1547,7 @@ public class NsdServiceTest { client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); waitForIdle(); verify(mSocketProvider).startMonitoringSockets(); - verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any()); + verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), anyInt()); final Network wifiNetwork1 = new Network(123); final Network wifiNetwork2 = new Network(124); diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt index 5c04362168be31020038276afebc3985ca4ce3b4..f753c9324696a0617b96b7a16713b1bdb7ab7a44 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt @@ -32,6 +32,7 @@ import com.android.connectivity.resources.R import com.android.net.module.util.SharedLog import com.android.server.connectivity.ConnectivityResources import com.android.server.connectivity.mdns.MdnsAdvertiser.AdvertiserCallback +import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE import com.android.server.connectivity.mdns.MdnsSocketProvider.SocketCallback import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.DevSdkIgnoreRunner @@ -76,6 +77,7 @@ private const val TEST_SUBTYPE = "_subtype" private const val TEST_SUBTYPE2 = "_subtype2" private val TEST_INTERFACE1 = "test_iface1" private val TEST_INTERFACE2 = "test_iface2" +private val TEST_CLIENT_UID_1 = 10010 private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03) private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04) private val DEFAULT_ADVERTISING_OPTION = MdnsAdvertisingOptions.getDefaultOptions() @@ -227,7 +229,7 @@ class MdnsAdvertiserTest { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags, context) postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture()) @@ -254,7 +256,10 @@ class MdnsAdvertiserTest { verify(cb).onOffloadStartOrUpdate(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO_NO_SUBTYPE)) // Service is conflicted. - postSync { intAdvCbCaptor.value.onServiceConflict(mockInterfaceAdvertiser1, SERVICE_ID_1) } + postSync { + intAdvCbCaptor.value + .onServiceConflict(mockInterfaceAdvertiser1, SERVICE_ID_1, CONFLICT_SERVICE) + } // Verify the metrics data doReturn(25).`when`(mockInterfaceAdvertiser1).getServiceRepliedRequestsCount(SERVICE_ID_1) @@ -289,7 +294,7 @@ class MdnsAdvertiserTest { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags, context) postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE_SUBTYPE, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE_SUBTYPE.network), @@ -327,9 +332,18 @@ class MdnsAdvertiserTest { argThat { it.matches(ALL_NETWORKS_SERVICE_SUBTYPE) }) // Services are conflicted. - postSync { intAdvCbCaptor1.value.onServiceConflict(mockInterfaceAdvertiser1, SERVICE_ID_1) } - postSync { intAdvCbCaptor1.value.onServiceConflict(mockInterfaceAdvertiser1, SERVICE_ID_1) } - postSync { intAdvCbCaptor2.value.onServiceConflict(mockInterfaceAdvertiser2, SERVICE_ID_1) } + postSync { + intAdvCbCaptor1.value + .onServiceConflict(mockInterfaceAdvertiser1, SERVICE_ID_1, CONFLICT_SERVICE) + } + postSync { + intAdvCbCaptor1.value + .onServiceConflict(mockInterfaceAdvertiser1, SERVICE_ID_1, CONFLICT_SERVICE) + } + postSync { + intAdvCbCaptor2.value + .onServiceConflict(mockInterfaceAdvertiser2, SERVICE_ID_1, CONFLICT_SERVICE) + } // Verify the metrics data doReturn(10).`when`(mockInterfaceAdvertiser1).getServiceRepliedRequestsCount(SERVICE_ID_1) @@ -361,18 +375,19 @@ class MdnsAdvertiserTest { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags, context) postSync { - advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, DEFAULT_ADVERTISING_OPTION) + advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, DEFAULT_ADVERTISING_OPTION, + TEST_CLIENT_UID_1) advertiser.addOrUpdateService(SERVICE_ID_2, NsdServiceInfo("TestService2", "_PRIORITYTEST._udp").apply { port = 12345 hostAddresses = listOf(TEST_ADDR) - }, DEFAULT_ADVERTISING_OPTION) + }, DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) advertiser.addOrUpdateService( SERVICE_ID_3, NsdServiceInfo("TestService3", "_notprioritized._tcp").apply { port = 12345 hostAddresses = listOf(TEST_ADDR) - }, DEFAULT_ADVERTISING_OPTION) + }, DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) @@ -419,7 +434,7 @@ class MdnsAdvertiserTest { val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags, context) postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture()) @@ -427,18 +442,18 @@ class MdnsAdvertiserTest { // Register a service with the same name on all networks (name conflict) postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture()) val allNetSocketCb = allNetSocketCbCaptor.value postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_1, LONG_SERVICE_1, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } postSync { advertiser.addOrUpdateService(CASE_INSENSITIVE_TEST_SERVICE_ID, - ALL_NETWORKS_SERVICE_2, DEFAULT_ADVERTISING_OPTION) } + ALL_NETWORKS_SERVICE_2, DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } // Callbacks for matching network and all networks both get the socket postSync { @@ -508,7 +523,7 @@ class MdnsAdvertiserTest { MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags, context) postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) verify(socketProvider).requestSocket(eq(null), socketCbCaptor.capture()) @@ -523,16 +538,17 @@ class MdnsAdvertiserTest { // Update with serviceId that is not registered yet should fail postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE_SUBTYPE, - updateOptions) } + updateOptions, TEST_CLIENT_UID_1) } verify(cb).onRegisterServiceFailed(SERVICE_ID_2, NsdManager.FAILURE_INTERNAL_ERROR) // Update service with different NsdServiceInfo should fail - postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1_SUBTYPE, updateOptions) } + postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1_SUBTYPE, updateOptions, + TEST_CLIENT_UID_1) } verify(cb).onRegisterServiceFailed(SERVICE_ID_1, NsdManager.FAILURE_INTERNAL_ERROR) // Update service with same NsdServiceInfo but different subType should succeed postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE_SUBTYPE, - updateOptions) } + updateOptions, TEST_CLIENT_UID_1) } verify(mockInterfaceAdvertiser1).updateService(eq(SERVICE_ID_1), eq(setOf(TEST_SUBTYPE))) // Newly created MdnsInterfaceAdvertiser will get addService() call. @@ -547,7 +563,7 @@ class MdnsAdvertiserTest { MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags, context) verify(mockDeps, times(1)).generateHostname() postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, - DEFAULT_ADVERTISING_OPTION) } + DEFAULT_ADVERTISING_OPTION, TEST_CLIENT_UID_1) } postSync { advertiser.removeService(SERVICE_ID_1) } verify(mockDeps, times(2)).generateHostname() } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt index 0e5cc50d7ce8c0a3a2124b7b13739eaa184e0147..0637ad11c56813a23f2100078276add6cdf58c3d 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt @@ -18,6 +18,7 @@ package com.android.server.connectivity.mdns import android.net.InetAddresses.parseNumericAddress import android.net.LinkAddress +import android.net.nsd.NsdManager import android.net.nsd.NsdServiceInfo import android.os.Build import android.os.HandlerThread @@ -26,6 +27,7 @@ import com.android.net.module.util.SharedLog import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo +import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo @@ -347,7 +349,8 @@ class MdnsInterfaceAdvertiserTest { @Test fun testConflict() { addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1) - doReturn(setOf(TEST_SERVICE_ID_1)).`when`(repository).getConflictingServices(any()) + doReturn(mapOf(TEST_SERVICE_ID_1 to CONFLICT_SERVICE)) + .`when`(repository).getConflictingServices(any()) // Reply obtained with: // scapy.raw(scapy.DNS( @@ -373,7 +376,7 @@ class MdnsInterfaceAdvertiserTest { } thread.waitForIdle(TIMEOUT_MS) - verify(cb).onServiceConflict(advertiser, TEST_SERVICE_ID_1) + verify(cb).onServiceConflict(advertiser, TEST_SERVICE_ID_1, CONFLICT_SERVICE) } @Test diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt index 06f12fe4608568bb436bfdeab7d7d95411a9c0a7..fd8d98ba008a09cb0938063cda49ea00adc2d141 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt @@ -22,6 +22,8 @@ import android.net.nsd.NsdServiceInfo import android.os.Build import android.os.HandlerThread import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo +import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST +import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR @@ -52,6 +54,9 @@ import org.junit.runner.RunWith private const val TEST_SERVICE_ID_1 = 42 private const val TEST_SERVICE_ID_2 = 43 private const val TEST_SERVICE_ID_3 = 44 +private const val TEST_CUSTOM_HOST_ID_1 = 45 +private const val TEST_CUSTOM_HOST_ID_2 = 46 +private const val TEST_SERVICE_CUSTOM_HOST_ID_1 = 48 private const val TEST_PORT = 12345 private const val TEST_SUBTYPE = "_subtype" private const val TEST_SUBTYPE2 = "_subtype2" @@ -86,6 +91,26 @@ private val TEST_SERVICE_3 = NsdServiceInfo().apply { port = TEST_PORT } +private val TEST_CUSTOM_HOST_1 = NsdServiceInfo().apply { + hostname = "TestHost" + hostAddresses = listOf(parseNumericAddress("2001:db8::1"), parseNumericAddress("2001:db8::2")) +} + +private val TEST_CUSTOM_HOST_1_NAME = arrayOf("TestHost", "local") + +private val TEST_CUSTOM_HOST_2 = NsdServiceInfo().apply { + hostname = "OtherTestHost" + hostAddresses = listOf(parseNumericAddress("2001:db8::3"), parseNumericAddress("2001:db8::4")) +} + +private val TEST_SERVICE_CUSTOM_HOST_1 = NsdServiceInfo().apply { + hostname = "TestHost" + hostAddresses = listOf(parseNumericAddress("2001:db8::1")) + serviceType = "_testservice._tcp" + serviceName = "TestService" + port = TEST_PORT +} + @RunWith(DevSdkIgnoreRunner::class) @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) class MdnsRecordRepositoryTest { @@ -569,6 +594,92 @@ class MdnsRecordRepositoryTest { ), reply.additionalAnswers) } + + @Test + fun testGetReply_ptrQuestionForServiceWithCustomHost_customHostUsedInAdditionalAnswers() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.initWithService(TEST_SERVICE_CUSTOM_HOST_ID_1, TEST_SERVICE_CUSTOM_HOST_1, + setOf(TEST_SUBTYPE, TEST_SUBTYPE2)) + val src = InetSocketAddress(parseNumericAddress("fe80::1234"), 5353) + val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local") + + val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local")) + val reply = repository.getReply(query, src) + + assertNotNull(reply) + assertEquals(listOf( + MdnsPointerRecord( + arrayOf("_testservice", "_tcp", "local"), + 0L, false, LONG_TTL, serviceName)), + reply.answers) + assertEquals(listOf( + MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf()), + MdnsServiceRecord(serviceName, 0L, true, SHORT_TTL, + 0, 0, TEST_PORT, TEST_CUSTOM_HOST_1_NAME), + MdnsInetAddressRecord( + TEST_CUSTOM_HOST_1_NAME, 0L, true, SHORT_TTL, + parseNumericAddress("2001:db8::1")), + MdnsNsecRecord(serviceName, 0L, true, LONG_TTL, serviceName /* nextDomain */, + intArrayOf(TYPE_TXT, TYPE_SRV)), + MdnsNsecRecord(TEST_CUSTOM_HOST_1_NAME, 0L, true, SHORT_TTL, + TEST_CUSTOM_HOST_1_NAME /* nextDomain */, + intArrayOf(TYPE_AAAA)), + ), reply.additionalAnswers) + } + + @Test + fun testGetReply_ptrQuestionForServicesWithSameCustomHost_customHostUsedInAdditionalAnswers() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + val serviceWithCustomHost1 = NsdServiceInfo().apply { + hostname = "TestHost" + hostAddresses = listOf( + parseNumericAddress("2001:db8::1"), + parseNumericAddress("192.0.2.1")) + serviceType = "_testservice._tcp" + serviceName = "TestService1" + port = TEST_PORT + } + val serviceWithCustomHost2 = NsdServiceInfo().apply { + hostname = "TestHost" + hostAddresses = listOf( + parseNumericAddress("2001:db8::1"), + parseNumericAddress("2001:db8::3")) + } + repository.addServiceAndFinishProbing(TEST_SERVICE_ID_1, serviceWithCustomHost1) + repository.addServiceAndFinishProbing(TEST_SERVICE_ID_2, serviceWithCustomHost2) + val src = InetSocketAddress(parseNumericAddress("fe80::1234"), 5353) + val serviceName = arrayOf("TestService1", "_testservice", "_tcp", "local") + + val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local")) + val reply = repository.getReply(query, src) + + assertNotNull(reply) + assertEquals(listOf( + MdnsPointerRecord( + arrayOf("_testservice", "_tcp", "local"), + 0L, false, LONG_TTL, serviceName)), + reply.answers) + assertEquals(listOf( + MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf()), + MdnsServiceRecord(serviceName, 0L, true, SHORT_TTL, + 0, 0, TEST_PORT, TEST_CUSTOM_HOST_1_NAME), + MdnsInetAddressRecord( + TEST_CUSTOM_HOST_1_NAME, 0L, true, SHORT_TTL, + parseNumericAddress("2001:db8::1")), + MdnsInetAddressRecord( + TEST_CUSTOM_HOST_1_NAME, 0L, true, SHORT_TTL, + parseNumericAddress("192.0.2.1")), + MdnsInetAddressRecord( + TEST_CUSTOM_HOST_1_NAME, 0L, true, SHORT_TTL, + parseNumericAddress("2001:db8::3")), + MdnsNsecRecord(serviceName, 0L, true, LONG_TTL, serviceName /* nextDomain */, + intArrayOf(TYPE_TXT, TYPE_SRV)), + MdnsNsecRecord(TEST_CUSTOM_HOST_1_NAME, 0L, true, SHORT_TTL, + TEST_CUSTOM_HOST_1_NAME /* nextDomain */, + intArrayOf(TYPE_A, TYPE_AAAA)), + ), reply.additionalAnswers) + } + @Test fun testGetReply_singleSubtypePtrQuestion_returnsSrvTxtAddressNsecRecords() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) @@ -707,6 +818,90 @@ class MdnsRecordRepositoryTest { reply.additionalAnswers) } + @Test + fun testGetReply_AAAAQuestionForCustomHost_returnsAAAARecords() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.initWithService( + TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1, subtypes = setOf(), + listOf(LinkAddress(parseNumericAddress("192.0.2.111"), 24))) + repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2) + val src = InetSocketAddress(parseNumericAddress("fe80::123"), 5353) + + val query = makeQuery(TYPE_AAAA to TEST_CUSTOM_HOST_1_NAME) + val reply = repository.getReply(query, src) + + assertNotNull(reply) + assertEquals(listOf( + MdnsInetAddressRecord(TEST_CUSTOM_HOST_1_NAME, + 0, false, LONG_TTL, parseNumericAddress("2001:db8::1")), + MdnsInetAddressRecord(TEST_CUSTOM_HOST_1_NAME, + 0, false, LONG_TTL, parseNumericAddress("2001:db8::2"))), + reply.answers) + assertEquals( + listOf(MdnsNsecRecord(TEST_CUSTOM_HOST_1_NAME, + 0L, true, SHORT_TTL, + TEST_CUSTOM_HOST_1_NAME /* nextDomain */, + intArrayOf(TYPE_AAAA))), + reply.additionalAnswers) + } + + + @Test + fun testGetReply_AAAAQuestionForCustomHostInMultipleRegistrations_returnsAAAARecords() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + + repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_1, NsdServiceInfo().apply { + hostname = "TestHost" + hostAddresses = listOf( + parseNumericAddress("2001:db8::1"), + parseNumericAddress("2001:db8::2")) + }) + repository.addServiceAndFinishProbing(TEST_CUSTOM_HOST_ID_2, NsdServiceInfo().apply { + hostname = "TestHost" + hostAddresses = listOf( + parseNumericAddress("2001:db8::1"), + parseNumericAddress("2001:db8::3")) + }) + val src = InetSocketAddress(parseNumericAddress("fe80::123"), 5353) + + val query = makeQuery(TYPE_AAAA to TEST_CUSTOM_HOST_1_NAME) + val reply = repository.getReply(query, src) + + assertNotNull(reply) + assertEquals(listOf( + MdnsInetAddressRecord(TEST_CUSTOM_HOST_1_NAME, + 0, false, LONG_TTL, parseNumericAddress("2001:db8::1")), + MdnsInetAddressRecord(TEST_CUSTOM_HOST_1_NAME, + 0, false, LONG_TTL, parseNumericAddress("2001:db8::2")), + MdnsInetAddressRecord(TEST_CUSTOM_HOST_1_NAME, + 0, false, LONG_TTL, parseNumericAddress("2001:db8::3"))), + reply.answers) + assertEquals( + listOf(MdnsNsecRecord(TEST_CUSTOM_HOST_1_NAME, + 0L, true, SHORT_TTL, + TEST_CUSTOM_HOST_1_NAME /* nextDomain */, + intArrayOf(TYPE_AAAA))), + reply.additionalAnswers) + } + + @Test + fun testGetReply_customHostRemoved_noAnswerToAAAAQuestion() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.initWithService( + TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1, subtypes = setOf(), + listOf(LinkAddress(parseNumericAddress("192.0.2.111"), 24))) + repository.addService(TEST_SERVICE_CUSTOM_HOST_ID_1, TEST_SERVICE_CUSTOM_HOST_1) + repository.removeService(TEST_CUSTOM_HOST_ID_1) + repository.removeService(TEST_SERVICE_CUSTOM_HOST_ID_1) + + val src = InetSocketAddress(parseNumericAddress("fe80::123"), 5353) + + val query = makeQuery(TYPE_AAAA to TEST_CUSTOM_HOST_1_NAME) + val reply = repository.getReply(query, src) + + assertNull(reply) + } + @Test fun testGetReply_ptrAndSrvQuestions_doesNotReturnSrvRecordInAdditionalAnswerSection() { val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) @@ -815,7 +1010,10 @@ class MdnsRecordRepositoryTest { emptyList() /* authorityRecords */, emptyList() /* additionalRecords */) - assertEquals(setOf(TEST_SERVICE_ID_1, TEST_SERVICE_ID_2), + assertEquals( + mapOf( + TEST_SERVICE_ID_1 to CONFLICT_SERVICE, + TEST_SERVICE_ID_2 to CONFLICT_SERVICE), repository.getConflictingServices(packet)) } @@ -843,8 +1041,131 @@ class MdnsRecordRepositoryTest { emptyList() /* authorityRecords */, emptyList() /* additionalRecords */) - assertEquals(setOf(TEST_SERVICE_ID_1, TEST_SERVICE_ID_2), - repository.getConflictingServices(packet)) + assertEquals( + mapOf(TEST_SERVICE_ID_1 to CONFLICT_SERVICE, + TEST_SERVICE_ID_2 to CONFLICT_SERVICE), + repository.getConflictingServices(packet)) + } + + @Test + fun testGetConflictingServices_customHosts_differentAddresses() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1) + repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2) + + val packet = MdnsPacket( + 0, /* flags */ + emptyList(), /* questions */ + listOf( + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::5")), + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::6")), + ) /* answers */, + emptyList() /* authorityRecords */, + emptyList() /* additionalRecords */) + + assertEquals(mapOf(TEST_CUSTOM_HOST_ID_1 to CONFLICT_HOST), + repository.getConflictingServices(packet)) + } + + @Test + fun testGetConflictingServices_customHosts_moreAddressesThanUs_conflict() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1) + repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2) + + val packet = MdnsPacket( + 0, /* flags */ + emptyList(), /* questions */ + listOf( + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::1")), + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::2")), + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::3")), + ) /* answers */, + emptyList() /* authorityRecords */, + emptyList() /* additionalRecords */) + + assertEquals(mapOf(TEST_CUSTOM_HOST_ID_1 to CONFLICT_HOST), + repository.getConflictingServices(packet)) + } + + @Test + fun testGetConflictingServices_customHostsReplyHasFewerAddressesThanUs_noConflict() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1) + repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2) + + val packet = MdnsPacket( + 0, /* flags */ + emptyList(), /* questions */ + listOf( + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::2")), + ) /* answers */, + emptyList() /* authorityRecords */, + emptyList() /* additionalRecords */) + + assertEquals(emptyMap(), + repository.getConflictingServices(packet)) + } + + @Test + fun testGetConflictingServices_customHostsReplyHasIdenticalHosts_noConflict() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1) + repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2) + + val packet = MdnsPacket( + 0, /* flags */ + emptyList(), /* questions */ + listOf( + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::1")), + MdnsInetAddressRecord(arrayOf("TestHost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::2")), + ) /* answers */, + emptyList() /* authorityRecords */, + emptyList() /* additionalRecords */) + + assertEquals(emptyMap(), + repository.getConflictingServices(packet)) + } + + + @Test + fun testGetConflictingServices_customHostsCaseInsensitiveReplyHasIdenticalHosts_noConflict() { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()) + repository.addService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1) + repository.addService(TEST_CUSTOM_HOST_ID_2, TEST_CUSTOM_HOST_2) + + val packet = MdnsPacket( + 0, /* flags */ + emptyList(), /* questions */ + listOf( + MdnsInetAddressRecord(arrayOf("TESTHOST", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::1")), + MdnsInetAddressRecord(arrayOf("testhost", "local"), + 0L /* receiptTimeMillis */, true /* cacheFlush */, + 0L /* ttlMillis */, parseNumericAddress("2001:db8::2")), + ) /* answers */, + emptyList() /* authorityRecords */, + emptyList() /* additionalRecords */) + + assertEquals(emptyMap(), + repository.getConflictingServices(packet)) } @Test @@ -873,7 +1194,7 @@ class MdnsRecordRepositoryTest { emptyList() /* additionalRecords */) // Above records are identical to the actual registrations: no conflict - assertEquals(emptySet(), repository.getConflictingServices(packet)) + assertEquals(emptyMap(), repository.getConflictingServices(packet)) } @Test @@ -902,7 +1223,7 @@ class MdnsRecordRepositoryTest { emptyList() /* additionalRecords */) // Above records are identical to the actual registrations: no conflict - assertEquals(emptySet(), repository.getConflictingServices(packet)) + assertEquals(emptyMap(), repository.getConflictingServices(packet)) } @Test