diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java index ec6af9b6d6af9c10124c8310fcdc4471f83a1c88..f9ee0df34a4954cdfeac8a2b2f6b96787f6087af 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java @@ -42,7 +42,7 @@ import java.util.Objects; * to their default value (0, false or null). */ public class MdnsServiceCache { - private static class CacheKey { + static class CacheKey { @NonNull final String mLowercaseServiceType; @NonNull final SocketKey mSocketKey; @@ -72,6 +72,12 @@ public class MdnsServiceCache { */ @NonNull private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>(); + /** + * A map of service expire callbacks. Key is composed of service type and socket and value is + * the callback listener. + */ + @NonNull + private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>(); @NonNull private final Handler mHandler; @@ -82,17 +88,14 @@ public class MdnsServiceCache { /** * Get the cache services which are queried from given service type and socket. * - * @param serviceType the target service type. - * @param socketKey the target socket + * @param cacheKey the target CacheKey. * @return the set of services which matches the given service type. */ @NonNull - public List<MdnsResponse> getCachedServices(@NonNull String serviceType, - @NonNull SocketKey socketKey) { + public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) { ensureRunningOnHandlerThread(mHandler); - final CacheKey key = new CacheKey(serviceType, socketKey); - return mCachedServices.containsKey(key) - ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key))) + return mCachedServices.containsKey(cacheKey) + ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey))) : Collections.emptyList(); } @@ -117,16 +120,13 @@ public class MdnsServiceCache { * Get the cache service. * * @param serviceName the target service name. - * @param serviceType the target service type. - * @param socketKey the target socket + * @param cacheKey the target CacheKey. * @return the service which matches given conditions. */ @Nullable - public MdnsResponse getCachedService(@NonNull String serviceName, - @NonNull String serviceType, @NonNull SocketKey socketKey) { + public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { ensureRunningOnHandlerThread(mHandler); - final List<MdnsResponse> responses = - mCachedServices.get(new CacheKey(serviceType, socketKey)); + final List<MdnsResponse> responses = mCachedServices.get(cacheKey); if (responses == null) { return null; } @@ -137,15 +137,13 @@ public class MdnsServiceCache { /** * Add or update a service. * - * @param serviceType the service type. - * @param socketKey the target socket + * @param cacheKey the target CacheKey. * @param response the response of the discovered service. */ - public void addOrUpdateService(@NonNull String serviceType, @NonNull SocketKey socketKey, - @NonNull MdnsResponse response) { + public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) { ensureRunningOnHandlerThread(mHandler); final List<MdnsResponse> responses = mCachedServices.computeIfAbsent( - new CacheKey(serviceType, socketKey), key -> new ArrayList<>()); + cacheKey, key -> new ArrayList<>()); // Remove existing service if present. final MdnsResponse existing = findMatchedResponse(responses, response.getServiceInstanceName()); @@ -157,15 +155,12 @@ public class MdnsServiceCache { * Remove a service which matches the given service name, type and socket. * * @param serviceName the target service name. - * @param serviceType the target service type. - * @param socketKey the target socket. + * @param cacheKey the target CacheKey. */ @Nullable - public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType, - @NonNull SocketKey socketKey) { + public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) { ensureRunningOnHandlerThread(mHandler); - final List<MdnsResponse> responses = - mCachedServices.get(new CacheKey(serviceType, socketKey)); + final List<MdnsResponse> responses = mCachedServices.get(cacheKey); if (responses == null) { return null; } @@ -180,5 +175,37 @@ public class MdnsServiceCache { return null; } + /** + * Register a callback to listen to service expiration. + * + * <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient + * relies on this. + * + * @param cacheKey the target CacheKey. + * @param callback the callback that notify the service is expired. + */ + public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey, + @NonNull ServiceExpiredCallback callback) { + ensureRunningOnHandlerThread(mHandler); + mCallbacks.put(cacheKey, callback); + } + + /** + * Unregister the service expired callback. + * + * @param cacheKey the CacheKey that is registered to listen service expiration before. + */ + public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) { + ensureRunningOnHandlerThread(mHandler); + mCallbacks.remove(cacheKey); + } + + /*** Callbacks for listening service expiration */ + public interface ServiceExpiredCallback { + /*** Notify the service is expired */ + void onServiceRecordExpired(@NonNull MdnsResponse previousResponse, + @Nullable MdnsResponse newResponse); + } + // TODO: check ttl expiration for each service and notify to the clients. } diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java index bbe8f4c5553987d51768257eec0182eb4fbf6248..0a031864a5f4721f8a52b7a7e1a65d83854d5d02 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java @@ -16,6 +16,7 @@ package com.android.server.connectivity.mdns; +import static com.android.server.connectivity.mdns.MdnsServiceCache.ServiceExpiredCallback; import static com.android.server.connectivity.mdns.MdnsServiceCache.findMatchedResponse; import static com.android.server.connectivity.mdns.util.MdnsUtils.Clock; import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread; @@ -71,6 +72,15 @@ public class MdnsServiceTypeClient { * The service caches for each socket. It should be accessed from looper thread only. */ @NonNull private final MdnsServiceCache serviceCache; + @NonNull private final MdnsServiceCache.CacheKey cacheKey; + @NonNull private final ServiceExpiredCallback serviceExpiredCallback = + new ServiceExpiredCallback() { + @Override + public void onServiceRecordExpired(@NonNull MdnsResponse previousResponse, + @Nullable MdnsResponse newResponse) { + notifyRemovedServiceToListeners(previousResponse, "Service record expired"); + } + }; private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners = new ArrayMap<>(); private final boolean removeServiceAfterTtlExpires = @@ -225,6 +235,16 @@ public class MdnsServiceTypeClient { this.dependencies = dependencies; this.serviceCache = serviceCache; this.mdnsQueryScheduler = new MdnsQueryScheduler(); + this.cacheKey = new MdnsServiceCache.CacheKey(serviceType, socketKey); + } + + /** + * Do the cleanup of the MdnsServiceTypeClient + */ + private void shutDown() { + removeScheduledTask(); + mdnsQueryScheduler.cancelScheduledRun(); + serviceCache.unregisterServiceExpiredCallback(cacheKey); } private static MdnsServiceInfo buildMdnsServiceInfoFromResponse( @@ -293,7 +313,7 @@ public class MdnsServiceTypeClient { boolean hadReply = false; if (listeners.put(listener, searchOptions) == null) { for (MdnsResponse existingResponse : - serviceCache.getCachedServices(serviceType, socketKey)) { + serviceCache.getCachedServices(cacheKey)) { if (!responseMatchesOptions(existingResponse, searchOptions)) continue; final MdnsServiceInfo info = buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels); @@ -341,6 +361,8 @@ public class MdnsServiceTypeClient { servicesToResolve.size() < listeners.size() /* sendDiscoveryQueries */); executor.submit(queryTask); } + + serviceCache.registerServiceExpiredCallback(cacheKey, serviceExpiredCallback); } /** @@ -390,8 +412,7 @@ public class MdnsServiceTypeClient { return listeners.isEmpty(); } if (listeners.isEmpty()) { - removeScheduledTask(); - mdnsQueryScheduler.cancelScheduledRun(); + shutDown(); } return listeners.isEmpty(); } @@ -404,8 +425,7 @@ public class MdnsServiceTypeClient { ensureRunningOnHandlerThread(handler); // Augment the list of current known responses, and generated responses for resolve // requests if there is no known response - final List<MdnsResponse> cachedList = - serviceCache.getCachedServices(serviceType, socketKey); + final List<MdnsResponse> cachedList = serviceCache.getCachedServices(cacheKey); final List<MdnsResponse> currentList = new ArrayList<>(cachedList); List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey); for (MdnsResponse additionalResponse : additionalResponses) { @@ -432,7 +452,7 @@ public class MdnsServiceTypeClient { } else if (findMatchedResponse(cachedList, serviceInstanceName) != null) { // If the response is not modified and already in the cache. The cache will // need to be updated to refresh the last receipt time. - serviceCache.addOrUpdateService(serviceType, socketKey, response); + serviceCache.addOrUpdateService(cacheKey, response); } } if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)) { @@ -458,44 +478,50 @@ public class MdnsServiceTypeClient { } } - /** Notify all services are removed because the socket is destroyed. */ - public void notifySocketDestroyed() { - ensureRunningOnHandlerThread(handler); - for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) { - final String name = response.getServiceInstanceName(); - if (name == null) continue; - for (int i = 0; i < listeners.size(); i++) { - if (!responseMatchesOptions(response, listeners.valueAt(i))) continue; - final MdnsServiceBrowserListener listener = listeners.keyAt(i); - final MdnsServiceInfo serviceInfo = - buildMdnsServiceInfoFromResponse(response, serviceTypeLabels); + private void notifyRemovedServiceToListeners(@NonNull MdnsResponse response, + @NonNull String message) { + for (int i = 0; i < listeners.size(); i++) { + if (!responseMatchesOptions(response, listeners.valueAt(i))) continue; + final MdnsServiceBrowserListener listener = listeners.keyAt(i); + if (response.getServiceInstanceName() != null) { + final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse( + response, serviceTypeLabels); if (response.isComplete()) { - sharedLog.log("Socket destroyed. onServiceRemoved: " + name); + sharedLog.log(message + ". onServiceRemoved: " + serviceInfo); listener.onServiceRemoved(serviceInfo); } - sharedLog.log("Socket destroyed. onServiceNameRemoved: " + name); + sharedLog.log(message + ". onServiceNameRemoved: " + serviceInfo); listener.onServiceNameRemoved(serviceInfo); } } - removeScheduledTask(); - mdnsQueryScheduler.cancelScheduledRun(); + } + + /** Notify all services are removed because the socket is destroyed. */ + public void notifySocketDestroyed() { + ensureRunningOnHandlerThread(handler); + for (MdnsResponse response : serviceCache.getCachedServices(cacheKey)) { + final String name = response.getServiceInstanceName(); + if (name == null) continue; + notifyRemovedServiceToListeners(response, "Socket destroyed"); + } + shutDown(); } private void onResponseModified(@NonNull MdnsResponse response) { final String serviceInstanceName = response.getServiceInstanceName(); final MdnsResponse currentResponse = - serviceCache.getCachedService(serviceInstanceName, serviceType, socketKey); + serviceCache.getCachedService(serviceInstanceName, cacheKey); boolean newServiceFound = false; boolean serviceBecomesComplete = false; if (currentResponse == null) { newServiceFound = true; if (serviceInstanceName != null) { - serviceCache.addOrUpdateService(serviceType, socketKey, response); + serviceCache.addOrUpdateService(cacheKey, response); } } else { boolean before = currentResponse.isComplete(); - serviceCache.addOrUpdateService(serviceType, socketKey, response); + serviceCache.addOrUpdateService(cacheKey, response); boolean after = response.isComplete(); serviceBecomesComplete = !before && after; } @@ -529,22 +555,11 @@ public class MdnsServiceTypeClient { private void onGoodbyeReceived(@Nullable String serviceInstanceName) { final MdnsResponse response = - serviceCache.removeService(serviceInstanceName, serviceType, socketKey); + serviceCache.removeService(serviceInstanceName, cacheKey); if (response == null) { return; } - for (int i = 0; i < listeners.size(); i++) { - if (!responseMatchesOptions(response, listeners.valueAt(i))) continue; - final MdnsServiceBrowserListener listener = listeners.keyAt(i); - final MdnsServiceInfo serviceInfo = - buildMdnsServiceInfoFromResponse(response, serviceTypeLabels); - if (response.isComplete()) { - sharedLog.log("onServiceRemoved: " + serviceInfo); - listener.onServiceRemoved(serviceInfo); - } - sharedLog.log("onServiceNameRemoved: " + serviceInfo); - listener.onServiceNameRemoved(serviceInfo); - } + notifyRemovedServiceToListeners(response, "Goodbye received"); } private boolean shouldRemoveServiceAfterTtlExpires() { @@ -567,7 +582,7 @@ public class MdnsServiceTypeClient { continue; } MdnsResponse knownResponse = - serviceCache.getCachedService(resolveName, serviceType, socketKey); + serviceCache.getCachedService(resolveName, cacheKey); if (knownResponse == null) { final ArrayList<String> instanceFullName = new ArrayList<>( serviceTypeLabels.length + 1); @@ -585,36 +600,18 @@ public class MdnsServiceTypeClient { private void tryRemoveServiceAfterTtlExpires() { if (!shouldRemoveServiceAfterTtlExpires()) return; - Iterator<MdnsResponse> iter = - serviceCache.getCachedServices(serviceType, socketKey).iterator(); + final Iterator<MdnsResponse> iter = serviceCache.getCachedServices(cacheKey).iterator(); while (iter.hasNext()) { MdnsResponse existingResponse = iter.next(); - final String serviceInstanceName = existingResponse.getServiceInstanceName(); if (existingResponse.hasServiceRecord() && existingResponse.getServiceRecord() .getRemainingTTL(clock.elapsedRealtime()) == 0) { - serviceCache.removeService(serviceInstanceName, serviceType, socketKey); - for (int i = 0; i < listeners.size(); i++) { - if (!responseMatchesOptions(existingResponse, listeners.valueAt(i))) { - continue; - } - final MdnsServiceBrowserListener listener = listeners.keyAt(i); - if (serviceInstanceName != null) { - final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse( - existingResponse, serviceTypeLabels); - if (existingResponse.isComplete()) { - sharedLog.log("TTL expired. onServiceRemoved: " + serviceInfo); - listener.onServiceRemoved(serviceInfo); - } - sharedLog.log("TTL expired. onServiceNameRemoved: " + serviceInfo); - listener.onServiceNameRemoved(serviceInfo); - } - } + serviceCache.removeService(existingResponse.getServiceInstanceName(), cacheKey); + notifyRemovedServiceToListeners(existingResponse, "TTL expired"); } } } - private static class QuerySentArguments { private final int transactionId; private final List<String> subTypes = new ArrayList<>(); @@ -672,7 +669,7 @@ public class MdnsServiceTypeClient { private long getMinRemainingTtl(long now) { long minRemainingTtl = Long.MAX_VALUE; - for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) { + for (MdnsResponse response : serviceCache.getCachedServices(cacheKey)) { if (!response.isComplete()) { continue; } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt index b43bcf7315a9f0cefbd1a0dd2f4727a12bca97f5..1b6f12027cb5bd03cb0df0d5396b660325c51d36 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt @@ -19,6 +19,7 @@ package com.android.server.connectivity.mdns import android.os.Build import android.os.Handler import android.os.HandlerThread +import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRunner import java.util.concurrent.CompletableFuture @@ -43,6 +44,8 @@ private const val DEFAULT_TIMEOUT_MS = 2000L @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2) class MdnsServiceCacheTest { private val socketKey = SocketKey(null /* network */, INTERFACE_INDEX) + private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey) + private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey) private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName) private val handler by lazy { Handler(thread.looper) @@ -69,47 +72,36 @@ class MdnsServiceCacheTest { return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS) } - private fun addOrUpdateService( - serviceType: String, - socketKey: SocketKey, - service: MdnsResponse - ): Unit = runningOnHandlerAndReturn { - serviceCache.addOrUpdateService(serviceType, socketKey, service) - } + private fun addOrUpdateService(cacheKey: CacheKey, service: MdnsResponse): Unit = + runningOnHandlerAndReturn { serviceCache.addOrUpdateService(cacheKey, service) } - private fun removeService(serviceName: String, serviceType: String, socketKey: SocketKey): - Unit = runningOnHandlerAndReturn { - serviceCache.removeService(serviceName, serviceType, socketKey) } + private fun removeService(serviceName: String, cacheKey: CacheKey): Unit = + runningOnHandlerAndReturn { serviceCache.removeService(serviceName, cacheKey) } - private fun getService(serviceName: String, serviceType: String, socketKey: SocketKey): - MdnsResponse? = runningOnHandlerAndReturn { - serviceCache.getCachedService(serviceName, serviceType, socketKey) } + private fun getService(serviceName: String, cacheKey: CacheKey): MdnsResponse? = + runningOnHandlerAndReturn { serviceCache.getCachedService(serviceName, cacheKey) } - private fun getServices(serviceType: String, socketKey: SocketKey): List<MdnsResponse> = - runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, socketKey) } + private fun getServices(cacheKey: CacheKey): List<MdnsResponse> = + runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) } @Test fun testAddAndRemoveService() { - addOrUpdateService( - SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) - var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey) + addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) + var response = getService(SERVICE_NAME_1, cacheKey1) assertNotNull(response) assertEquals(SERVICE_NAME_1, response.serviceInstanceName) - removeService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey) - response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey) + removeService(SERVICE_NAME_1, cacheKey1) + response = getService(SERVICE_NAME_1, cacheKey1) assertNull(response) } @Test fun testGetCachedServices_multipleServiceTypes() { - addOrUpdateService( - SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) - addOrUpdateService( - SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1)) - addOrUpdateService( - SERVICE_TYPE_2, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2)) - - val responses1 = getServices(SERVICE_TYPE_1, socketKey) + addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1)) + addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1)) + addOrUpdateService(cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2)) + + val responses1 = getServices(cacheKey1) assertEquals(2, responses1.size) assertTrue(responses1.stream().anyMatch { response -> response.serviceInstanceName == SERVICE_NAME_1 @@ -117,19 +109,19 @@ class MdnsServiceCacheTest { assertTrue(responses1.any { response -> response.serviceInstanceName == SERVICE_NAME_2 }) - val responses2 = getServices(SERVICE_TYPE_2, socketKey) + val responses2 = getServices(cacheKey2) assertEquals(1, responses2.size) assertTrue(responses2.any { response -> response.serviceInstanceName == SERVICE_NAME_2 }) - removeService(SERVICE_NAME_2, SERVICE_TYPE_1, socketKey) - val responses3 = getServices(SERVICE_TYPE_1, socketKey) + removeService(SERVICE_NAME_2, cacheKey1) + val responses3 = getServices(cacheKey1) assertEquals(1, responses3.size) assertTrue(responses3.any { response -> response.serviceInstanceName == SERVICE_NAME_1 }) - val responses4 = getServices(SERVICE_TYPE_2, socketKey) + val responses4 = getServices(cacheKey2) assertEquals(1, responses4.size) assertTrue(responses4.any { response -> response.serviceInstanceName == SERVICE_NAME_2