diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java index 9c01ddaf192556cc03f4f7ece76a46ea683e3757..2ac2b18378bdd58b529e42e94cc630b3592f7e17 100644 --- a/service-t/src/com/android/server/NsdService.java +++ b/service-t/src/com/android/server/NsdService.java @@ -1702,6 +1702,8 @@ public class NsdService extends INsdManager.Stub { mContext, MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL)) .setIsLabelCountLimitEnabled(mDeps.isTetheringFeatureNotChickenedOut( mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT)) + .setIsKnownAnswerSuppressionEnabled(mDeps.isFeatureEnabled( + mContext, MdnsFeatureFlags.NSD_KNOWN_ANSWER_SUPPRESSION)) .build(); mMdnsSocketClient = new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider, diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java index 0a6d8c130be5ad1d62b23031c0dbef00e068ba68..1ad47a30b7862e9a1d85ccd81c280e02e259f796 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java @@ -41,6 +41,11 @@ public class MdnsFeatureFlags { */ public static final String NSD_LIMIT_LABEL_COUNT = "nsd_limit_label_count"; + /** + * A feature flag to control whether the known-answer suppression should be enabled. + */ + public static final String NSD_KNOWN_ANSWER_SUPPRESSION = "nsd_known_answer_suppression"; + // Flag for offload feature public final boolean mIsMdnsOffloadFeatureEnabled; @@ -53,17 +58,22 @@ public class MdnsFeatureFlags { // Flag for label count limit public final boolean mIsLabelCountLimitEnabled; + // Flag for known-answer suppression + public final boolean mIsKnownAnswerSuppressionEnabled; + /** * The constructor for {@link MdnsFeatureFlags}. */ public MdnsFeatureFlags(boolean isOffloadFeatureEnabled, boolean includeInetAddressRecordsInProbing, boolean isExpiredServicesRemovalEnabled, - boolean isLabelCountLimitEnabled) { + boolean isLabelCountLimitEnabled, + boolean isKnownAnswerSuppressionEnabled) { mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled; mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing; mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled; mIsLabelCountLimitEnabled = isLabelCountLimitEnabled; + mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled; } @@ -79,6 +89,7 @@ public class MdnsFeatureFlags { private boolean mIncludeInetAddressRecordsInProbing; private boolean mIsExpiredServicesRemovalEnabled; private boolean mIsLabelCountLimitEnabled; + private boolean mIsKnownAnswerSuppressionEnabled; /** * The constructor for {@link Builder}. @@ -88,6 +99,7 @@ public class MdnsFeatureFlags { mIncludeInetAddressRecordsInProbing = false; mIsExpiredServicesRemovalEnabled = false; mIsLabelCountLimitEnabled = true; // Default enabled. + mIsKnownAnswerSuppressionEnabled = false; } /** @@ -131,6 +143,16 @@ public class MdnsFeatureFlags { return this; } + /** + * Set whether the known-answer suppression is enabled. + * + * @see #NSD_KNOWN_ANSWER_SUPPRESSION + */ + public Builder setIsKnownAnswerSuppressionEnabled(boolean isKnownAnswerSuppressionEnabled) { + mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled; + return this; + } + /** * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder. */ @@ -138,7 +160,8 @@ public class MdnsFeatureFlags { return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled, mIncludeInetAddressRecordsInProbing, mIsExpiredServicesRemovalEnabled, - mIsLabelCountLimitEnabled); + mIsLabelCountLimitEnabled, + mIsKnownAnswerSuppressionEnabled); } } } diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java index 1909208700c080f772d83ce49756822b91e2cd3d..d46a7b547437e7f00c3c0dd12d081b808cf1a47b 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java @@ -90,6 +90,7 @@ public class MdnsRecordRepository { private final Looper mLooper; @NonNull private final String[] mDeviceHostname; + @NonNull private final MdnsFeatureFlags mMdnsFeatureFlags; public MdnsRecordRepository(@NonNull Looper looper, @NonNull String[] deviceHostname, @@ -502,7 +503,7 @@ public class MdnsRecordRepository { // Add answers from general records addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */, null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now, - answerInfo, additionalAnswerRecords); + answerInfo, additionalAnswerRecords, Collections.emptyList()); // Add answers from each service for (int i = 0; i < mServices.size(); i++) { @@ -510,7 +511,7 @@ public class MdnsRecordRepository { if (registration.exiting || registration.isProbing) continue; if (addReplyFromService(question, registration.allRecords, registration.ptrRecords, registration.srvRecord, registration.txtRecord, replyUnicast, now, - answerInfo, additionalAnswerRecords)) { + answerInfo, additionalAnswerRecords, packet.answers)) { registration.repliedServiceCount++; registration.sentPacketCount++; } @@ -563,6 +564,15 @@ public class MdnsRecordRepository { return new MdnsReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest); } + private boolean isKnownAnswer(MdnsRecord answer, @NonNull List<MdnsRecord> knownAnswerRecords) { + for (MdnsRecord knownAnswer : knownAnswerRecords) { + if (answer.equals(knownAnswer) && knownAnswer.getTtl() > (answer.getTtl() / 2)) { + return true; + } + } + return false; + } + /** * Add answers and additional answers for a question, from a ServiceRegistration. */ @@ -572,7 +582,8 @@ public class MdnsRecordRepository { @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord, @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord, boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo, - @NonNull List<MdnsRecord> additionalAnswerRecords) { + @NonNull List<MdnsRecord> additionalAnswerRecords, + @NonNull List<MdnsRecord> knownAnswerRecords) { boolean hasDnsSdPtrRecordAnswer = false; boolean hasDnsSdSrvRecordAnswer = false; boolean hasFullyOwnedNameMatch = false; @@ -601,6 +612,20 @@ public class MdnsRecordRepository { } hasKnownAnswer = true; + + // RFC6762 7.1. Known-Answer Suppression: + // A Multicast DNS responder MUST NOT answer a Multicast DNS query if + // the answer it would give is already included in the Answer Section + // with an RR TTL at least half the correct value. If the RR TTL of the + // answer as given in the Answer Section is less than half of the true + // RR TTL as known by the Multicast DNS responder, the responder MUST + // send an answer so as to update the querier's cache before the record + // becomes in danger of expiration. + if (mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled + && isKnownAnswer(info.record, knownAnswerRecords)) { + continue; + } + hasDnsSdPtrRecordAnswer |= (servicePtrRecords != null && CollectionUtils.any(servicePtrRecords, r -> info == r)); hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord); @@ -612,8 +637,6 @@ public class MdnsRecordRepository { continue; } - // TODO: Don't reply if in known answers of the querier (7.1) if TTL is > half - answerInfo.add(info); } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt index 85e361d04e637c534d47b267f0e5f63c5ce20d76..196f73ff22e088505d932f85e58ff8d08a349f75 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt @@ -36,6 +36,7 @@ import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertNotNull +import kotlin.test.assertNull import kotlin.test.assertTrue import org.junit.After import org.junit.Before @@ -48,6 +49,13 @@ private const val TEST_SERVICE_ID_3 = 44 private const val TEST_PORT = 12345 private const val TEST_SUBTYPE = "_subtype" private const val TEST_SUBTYPE2 = "_subtype2" +// RFC6762 10. Resource Record TTL Values and Cache Coherency +// The recommended TTL value for Multicast DNS resource records with a host name as the resource +// record's name (e.g., A, AAAA, HINFO) or a host name contained within the resource record's rdata +// (e.g., SRV, reverse mapping PTR record) SHOULD be 120 seconds. The recommended TTL value for +// other Multicast DNS resource records is 75 minutes. +private const val LONG_TTL = 4_500_000L +private const val SHORT_TTL = 120_000L private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local") private val TEST_ADDRESSES = listOf( LinkAddress(parseNumericAddress("192.0.2.111"), 24), @@ -120,7 +128,7 @@ class MdnsRecordRepositoryTest { assertEquals(MdnsServiceRecord(expectedName, 0L /* receiptTimeMillis */, false /* cacheFlush */, - 120_000L /* ttlMillis */, + SHORT_TTL /* ttlMillis */, 0 /* servicePriority */, 0 /* serviceWeight */, TEST_PORT, TEST_HOSTNAME), packet.authorityRecords[0]) @@ -524,9 +532,6 @@ class MdnsRecordRepositoryTest { assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address) assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port) - // TTLs as per RFC6762 10. - val longTtl = 4_500_000L - val shortTtl = 120_000L val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") assertEquals(listOf( @@ -534,7 +539,7 @@ class MdnsRecordRepositoryTest { queriedName, 0L /* receiptTimeMillis */, false /* cacheFlush */, - longTtl, + LONG_TTL, serviceName), ), reply.answers) @@ -543,13 +548,13 @@ class MdnsRecordRepositoryTest { serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */, - longTtl, + LONG_TTL, listOf() /* entries */), MdnsServiceRecord( serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */, - shortTtl, + SHORT_TTL, 0 /* servicePriority */, 0 /* serviceWeight */, TEST_PORT, @@ -558,32 +563,32 @@ class MdnsRecordRepositoryTest { TEST_HOSTNAME, 0L /* receiptTimeMillis */, true /* cacheFlush */, - shortTtl, + SHORT_TTL, TEST_ADDRESSES[0].address), MdnsInetAddressRecord( TEST_HOSTNAME, 0L /* receiptTimeMillis */, true /* cacheFlush */, - shortTtl, + SHORT_TTL, TEST_ADDRESSES[1].address), MdnsInetAddressRecord( TEST_HOSTNAME, 0L /* receiptTimeMillis */, true /* cacheFlush */, - shortTtl, + SHORT_TTL, TEST_ADDRESSES[2].address), MdnsNsecRecord( serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */, - longTtl, + LONG_TTL, serviceName /* nextDomain */, intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)), MdnsNsecRecord( TEST_HOSTNAME, 0L /* receiptTimeMillis */, true /* cacheFlush */, - shortTtl, + SHORT_TTL, TEST_HOSTNAME /* nextDomain */, intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)), ), reply.additionalAnswers) @@ -760,7 +765,7 @@ class MdnsRecordRepositoryTest { expectedName, 0L /* receiptTimeMillis */, false /* cacheFlush */, - 120_000L /* ttlMillis */, + SHORT_TTL /* ttlMillis */, 0 /* servicePriority */, 0 /* serviceWeight */, TEST_PORT, @@ -769,24 +774,290 @@ class MdnsRecordRepositoryTest { TEST_HOSTNAME, 0L /* receiptTimeMillis */, false /* cacheFlush */, - 120_000L /* ttlMillis */, + SHORT_TTL /* ttlMillis */, TEST_ADDRESSES[0].address), MdnsInetAddressRecord( TEST_HOSTNAME, 0L /* receiptTimeMillis */, false /* cacheFlush */, - 120_000L /* ttlMillis */, + SHORT_TTL /* ttlMillis */, TEST_ADDRESSES[1].address), MdnsInetAddressRecord( TEST_HOSTNAME, 0L /* receiptTimeMillis */, false /* cacheFlush */, - 120_000L /* ttlMillis */, + SHORT_TTL /* ttlMillis */, TEST_ADDRESSES[2].address) ), packet.authorityRecords) assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices()) } + + private fun doGetReplyWithAnswersTest( + questions: List<MdnsRecord>, + knownAnswers: List<MdnsRecord>, + replyAnswers: List<MdnsRecord>, + additionalAnswers: List<MdnsRecord>, + expectReply: Boolean + ) { + val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, + MdnsFeatureFlags.newBuilder().setIsKnownAnswerSuppressionEnabled(true).build()) + repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1) + val query = MdnsPacket(0 /* flags */, questions, knownAnswers, + listOf() /* authorityRecords */, listOf() /* additionalRecords */) + val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353) + val reply = repository.getReply(query, src) + + if (!expectReply) { + assertNull(reply) + return + } + + assertNotNull(reply) + // Source address is IPv4 + assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address) + assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port) + assertEquals(replyAnswers, reply.answers) + assertEquals(additionalAnswers, reply.additionalAnswers) + } + + @Test + fun testGetReply_HasAnswers() { + val queriedName = arrayOf("_testservice", "_tcp", "local") + val questions = listOf(MdnsPointerRecord(queriedName, false /* isUnicast */)) + val knownAnswers = listOf(MdnsPointerRecord( + arrayOf("_testservice", "_tcp", "local"), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + LONG_TTL, + arrayOf("MyTestService", "_testservice", "_tcp", "local"))) + doGetReplyWithAnswersTest(questions, knownAnswers, listOf() /* replyAnswers */, + listOf() /* additionalAnswers */, false /* expectReply */) + } + + @Test + fun testGetReply_HasAnswers_TtlLessThanHalf() { + val queriedName = arrayOf("_testservice", "_tcp", "local") + val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") + val questions = listOf(MdnsPointerRecord(queriedName, false /* isUnicast */)) + val knownAnswers = listOf(MdnsPointerRecord( + arrayOf("_testservice", "_tcp", "local"), + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + (LONG_TTL / 2 - 1000L), + arrayOf("MyTestService", "_testservice", "_tcp", "local"))) + val replyAnswers = listOf(MdnsPointerRecord( + queriedName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + LONG_TTL, + serviceName)) + val additionalAnswers = listOf( + MdnsTextRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + LONG_TTL, + listOf() /* entries */), + MdnsServiceRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT, + TEST_HOSTNAME), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[0].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[1].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[2].address), + MdnsNsecRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + LONG_TTL, + serviceName /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)), + MdnsNsecRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_HOSTNAME /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA))) + doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers, + true /* expectReply */) + } + + @Test + fun testGetReply_HasAnotherAnswer() { + val queriedName = arrayOf("_testservice", "_tcp", "local") + val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") + val questions = listOf(MdnsPointerRecord(queriedName, false /* isUnicast */)) + val knownAnswers = listOf(MdnsPointerRecord( + queriedName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + LONG_TTL, + arrayOf("MyOtherTestService", "_testservice", "_tcp", "local"))) + val replyAnswers = listOf(MdnsPointerRecord( + queriedName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + LONG_TTL, + serviceName)) + val additionalAnswers = listOf( + MdnsTextRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + LONG_TTL, + listOf() /* entries */), + MdnsServiceRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT, + TEST_HOSTNAME), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[0].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[1].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[2].address), + MdnsNsecRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + LONG_TTL, + serviceName /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)), + MdnsNsecRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_HOSTNAME /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA))) + doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers, + true /* expectReply */) + } + + @Test + fun testGetReply_HasAnswers_MultiQuestions() { + val queriedName = arrayOf("_testservice", "_tcp", "local") + val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") + val questions = listOf( + MdnsPointerRecord(queriedName, false /* isUnicast */), + MdnsServiceRecord(serviceName, false /* isUnicast */)) + val knownAnswers = listOf(MdnsPointerRecord( + queriedName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + LONG_TTL - 1000L, + serviceName)) + val replyAnswers = listOf(MdnsServiceRecord( + serviceName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + SHORT_TTL /* ttlMillis */, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT, + TEST_HOSTNAME)) + val additionalAnswers = listOf( + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[0].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[1].address), + MdnsInetAddressRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_ADDRESSES[2].address), + MdnsNsecRecord( + serviceName, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + LONG_TTL, + serviceName /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_SRV)), + MdnsNsecRecord( + TEST_HOSTNAME, + 0L /* receiptTimeMillis */, + true /* cacheFlush */, + SHORT_TTL, + TEST_HOSTNAME /* nextDomain */, + intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA))) + doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers, + true /* expectReply */) + } + + @Test + fun testGetReply_HasAnswers_MultiQuestions_NoReply() { + val queriedName = arrayOf("_testservice", "_tcp", "local") + val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local") + val questions = listOf( + MdnsPointerRecord(queriedName, false /* isUnicast */), + MdnsServiceRecord(serviceName, false /* isUnicast */)) + val knownAnswers = listOf( + MdnsPointerRecord( + queriedName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + LONG_TTL - 1000L, + serviceName), + MdnsServiceRecord( + serviceName, + 0L /* receiptTimeMillis */, + false /* cacheFlush */, + SHORT_TTL - 15_000L, + 0 /* servicePriority */, + 0 /* serviceWeight */, + TEST_PORT, + TEST_HOSTNAME)) + doGetReplyWithAnswersTest(questions, knownAnswers, listOf() /* replyAnswers */, + listOf() /* additionalAnswers */, false /* expectReply */) + } } private fun MdnsRecordRepository.initWithService(