From c9bbe6b59e9d48e690fbbd86734de4af3b94413b Mon Sep 17 00:00:00 2001
From: Neil Fuller <nfuller@google.com>
Date: Wed, 6 Oct 2021 14:51:30 +0100
Subject: [PATCH] Fix SntpClient 2036 issue (2/2)

Fix issue with SntpClient after the end of NTP era 0 (2036).

This is the second of two commits. This commit makes the actual fixes
and makes tests pass.

Before this change SntpClient converted to Unix epoch times too eagerly.
NTP 64-bit timestamps are lossy: they only hold the number of seconds /
factions of seconds in the NTP era and the era is not transmitted. The
existing code assumed the era was always era 0, which ends in 2036.

As explained at https://www.eecis.udel.edu/~mills/y2k.html,
the lossiness of the type is not an issue providing that the maths is
implemented carefully: the NTP timestamps are only ever subtracted from
each other, are always assumed to be in the same or adjacent NTP eras,
and are used to calculate offsets that are applied to client Unix epoch
times.

This commit:

+ Switches to use a dedicated Timestamp64 type, avoiding the use
of the Unix epoch.
+ Switches to use a dedicated Duration64 type for holding the
32-bit signed difference between two Timestamp64 instances.
+ Simplifies the readTimeStamp() and writeTimeStamp() methods.
+ Adds missing validation covered by a TODO. The code was randomizing
the lower bits of the client transmit timestamp, but then not checking
the result as it should, presumably because it was difficult to know
what value was sent. Easily fixed with a dedicated type.
+ Stops randomizing the lower bits of various other timestamps
unnecessarily.
+ Fixes some naming to add clarity.

Bug: 199481251
Test: atest core/tests/coretests/src/android/net/sntp/Timestamp64Test.java
Test: atest core/tests/coretests/src/android/net/sntp/Duration64Test.java
Test: atest core/tests/coretests/src/android/net/SntpClientTest.java
Merged-In: I6d3584f318b0ef6ceab42bb88f20c73b0ad006cb
Change-Id: I6d3584f318b0ef6ceab42bb88f20c73b0ad006cb
---
 core/java/android/net/SntpClient.java         | 176 +++++++++---------
 .../src/android/net/SntpClientTest.java       |  24 ++-
 .../android/net/sntp/PredictableRandom.java   |  34 ++++
 .../src/android/net/sntp/Timestamp64Test.java |  93 +++++++++
 4 files changed, 234 insertions(+), 93 deletions(-)
 create mode 100644 core/tests/coretests/src/android/net/sntp/PredictableRandom.java

diff --git a/core/java/android/net/SntpClient.java b/core/java/android/net/SntpClient.java
index aea11fad7832..0eb4cf3ecadf 100644
--- a/core/java/android/net/SntpClient.java
+++ b/core/java/android/net/SntpClient.java
@@ -17,8 +17,11 @@
 package android.net;
 
 import android.compat.annotation.UnsupportedAppUsage;
+import android.net.sntp.Duration64;
+import android.net.sntp.Timestamp64;
 import android.os.SystemClock;
 import android.util.Log;
+import android.util.Slog;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.TrafficStatsConstants;
@@ -27,10 +30,12 @@ import java.net.DatagramPacket;
 import java.net.DatagramSocket;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
+import java.security.NoSuchAlgorithmException;
+import java.security.SecureRandom;
 import java.time.Duration;
 import java.time.Instant;
-import java.util.Arrays;
 import java.util.Objects;
+import java.util.Random;
 import java.util.function.Supplier;
 
 /**
@@ -65,13 +70,11 @@ public class SntpClient {
     private static final int NTP_STRATUM_DEATH = 0;
     private static final int NTP_STRATUM_MAX = 15;
 
-    // Number of seconds between Jan 1, 1900 and Jan 1, 1970
-    // 70 years plus 17 leap days
-    private static final long OFFSET_1900_TO_1970 = ((365L * 70L) + 17L) * 24L * 60L * 60L;
-
     // The source of the current system clock time, replaceable for testing.
     private final Supplier<Instant> mSystemTimeSupplier;
 
+    private final Random mRandom;
+
     // The last offset calculated from an NTP server response
     private long mClockOffset;
 
@@ -92,12 +95,13 @@ public class SntpClient {
 
     @UnsupportedAppUsage
     public SntpClient() {
-        this(Instant::now);
+        this(Instant::now, defaultRandom());
     }
 
     @VisibleForTesting
-    public SntpClient(Supplier<Instant> systemTimeSupplier) {
+    public SntpClient(Supplier<Instant> systemTimeSupplier, Random random) {
         mSystemTimeSupplier = Objects.requireNonNull(systemTimeSupplier);
+        mRandom = Objects.requireNonNull(random);
     }
 
     /**
@@ -144,10 +148,12 @@ public class SntpClient {
 
             // get current time and write it to the request packet
             final Instant requestTime = mSystemTimeSupplier.get();
-            final long requestTimestamp = requestTime.toEpochMilli();
+            final Timestamp64 requestTimestamp = Timestamp64.fromInstant(requestTime);
 
+            final Timestamp64 randomizedRequestTimestamp =
+                    requestTimestamp.randomizeSubMillis(mRandom);
             final long requestTicks = SystemClock.elapsedRealtime();
-            writeTimeStamp(buffer, TRANSMIT_TIME_OFFSET, requestTimestamp);
+            writeTimeStamp(buffer, TRANSMIT_TIME_OFFSET, randomizedRequestTimestamp);
 
             socket.send(request);
 
@@ -156,23 +162,25 @@ public class SntpClient {
             socket.receive(response);
             final long responseTicks = SystemClock.elapsedRealtime();
             final Instant responseTime = requestTime.plusMillis(responseTicks - requestTicks);
-            final long responseTimestamp = responseTime.toEpochMilli();
+            final Timestamp64 responseTimestamp = Timestamp64.fromInstant(responseTime);
 
             // extract the results
             final byte leap = (byte) ((buffer[0] >> 6) & 0x3);
             final byte mode = (byte) (buffer[0] & 0x7);
             final int stratum = (int) (buffer[1] & 0xff);
-            final long originateTimestamp = readTimeStamp(buffer, ORIGINATE_TIME_OFFSET);
-            final long receiveTimestamp = readTimeStamp(buffer, RECEIVE_TIME_OFFSET);
-            final long transmitTimestamp = readTimeStamp(buffer, TRANSMIT_TIME_OFFSET);
-            final long referenceTimestamp = readTimeStamp(buffer, REFERENCE_TIME_OFFSET);
+            final Timestamp64 referenceTimestamp = readTimeStamp(buffer, REFERENCE_TIME_OFFSET);
+            final Timestamp64 originateTimestamp = readTimeStamp(buffer, ORIGINATE_TIME_OFFSET);
+            final Timestamp64 receiveTimestamp = readTimeStamp(buffer, RECEIVE_TIME_OFFSET);
+            final Timestamp64 transmitTimestamp = readTimeStamp(buffer, TRANSMIT_TIME_OFFSET);
 
             /* Do validation according to RFC */
-            // TODO: validate originateTime == requestTime.
-            checkValidServerReply(leap, mode, stratum, transmitTimestamp, referenceTimestamp);
+            checkValidServerReply(leap, mode, stratum, transmitTimestamp, referenceTimestamp,
+                    randomizedRequestTimestamp, originateTimestamp);
 
-            long roundTripTimeMillis = responseTicks - requestTicks
-                    - (transmitTimestamp - receiveTimestamp);
+            long totalTransactionDurationMillis = responseTicks - requestTicks;
+            long serverDurationMillis =
+                    Duration64.between(receiveTimestamp, transmitTimestamp).toDuration().toMillis();
+            long roundTripTimeMillis = totalTransactionDurationMillis - serverDurationMillis;
 
             Duration clockOffsetDuration = calculateClockOffset(requestTimestamp,
                     receiveTimestamp, transmitTimestamp, responseTimestamp);
@@ -207,20 +215,24 @@ public class SntpClient {
 
     /** Performs the NTP clock offset calculation. */
     @VisibleForTesting
-    public static Duration calculateClockOffset(long clientRequestTimestamp,
-            long serverReceiveTimestamp, long serverTransmitTimestamp,
-            long clientResponseTimestamp) {
-        // receiveTime = originateTime + transit + skew
-        // responseTime = transmitTime + transit - skew
-        // clockOffset = ((receiveTime - originateTime) + (transmitTime - responseTime))/2
-        //             = ((originateTime + transit + skew - originateTime) +
-        //                (transmitTime - (transmitTime + transit - skew)))/2
-        //             = ((transit + skew) + (transmitTime - transmitTime - transit + skew))/2
-        //             = (transit + skew - transit + skew)/2
-        //             = (2 * skew)/2 = skew
-        long clockOffsetMillis = ((serverReceiveTimestamp - clientRequestTimestamp)
-                + (serverTransmitTimestamp - clientResponseTimestamp)) / 2;
-        return Duration.ofMillis(clockOffsetMillis);
+    public static Duration calculateClockOffset(Timestamp64 clientRequestTimestamp,
+            Timestamp64 serverReceiveTimestamp, Timestamp64 serverTransmitTimestamp,
+            Timestamp64 clientResponseTimestamp) {
+        // According to RFC4330:
+        // t is the system clock offset (the adjustment we are trying to find)
+        // t = ((T2 - T1) + (T3 - T4)) / 2
+        //
+        // Which is:
+        // t = (([server]receiveTimestamp - [client]requestTimestamp)
+        //       + ([server]transmitTimestamp - [client]responseTimestamp)) / 2
+        //
+        // See the NTP spec and tests: the numeric types used are deliberate:
+        // + Duration64.between() uses 64-bit arithmetic (32-bit for the seconds).
+        // + plus() / dividedBy() use Duration, which isn't the double precision floating point
+        //   used in NTPv4, but is good enough.
+        return Duration64.between(clientRequestTimestamp, serverReceiveTimestamp)
+                .plus(Duration64.between(clientResponseTimestamp, serverTransmitTimestamp))
+                .dividedBy(2);
     }
 
     @Deprecated
@@ -270,8 +282,9 @@ public class SntpClient {
     }
 
     private static void checkValidServerReply(
-            byte leap, byte mode, int stratum, long transmitTime, long referenceTime)
-            throws InvalidServerReplyException {
+            byte leap, byte mode, int stratum, Timestamp64 transmitTimestamp,
+            Timestamp64 referenceTimestamp, Timestamp64 randomizedRequestTimestamp,
+            Timestamp64 originateTimestamp) throws InvalidServerReplyException {
         if (leap == NTP_LEAP_NOSYNC) {
             throw new InvalidServerReplyException("unsynchronized server");
         }
@@ -281,73 +294,68 @@ public class SntpClient {
         if ((stratum == NTP_STRATUM_DEATH) || (stratum > NTP_STRATUM_MAX)) {
             throw new InvalidServerReplyException("untrusted stratum: " + stratum);
         }
-        if (transmitTime == 0) {
-            throw new InvalidServerReplyException("zero transmitTime");
+        if (!randomizedRequestTimestamp.equals(originateTimestamp)) {
+            throw new InvalidServerReplyException(
+                    "originateTimestamp != randomizedRequestTimestamp");
+        }
+        if (transmitTimestamp.equals(Timestamp64.ZERO)) {
+            throw new InvalidServerReplyException("zero transmitTimestamp");
         }
-        if (referenceTime == 0) {
-            throw new InvalidServerReplyException("zero reference timestamp");
+        if (referenceTimestamp.equals(Timestamp64.ZERO)) {
+            throw new InvalidServerReplyException("zero referenceTimestamp");
         }
     }
 
     /**
      * Reads an unsigned 32 bit big endian number from the given offset in the buffer.
      */
-    private long read32(byte[] buffer, int offset) {
-        byte b0 = buffer[offset];
-        byte b1 = buffer[offset+1];
-        byte b2 = buffer[offset+2];
-        byte b3 = buffer[offset+3];
-
-        // convert signed bytes to unsigned values
-        int i0 = ((b0 & 0x80) == 0x80 ? (b0 & 0x7F) + 0x80 : b0);
-        int i1 = ((b1 & 0x80) == 0x80 ? (b1 & 0x7F) + 0x80 : b1);
-        int i2 = ((b2 & 0x80) == 0x80 ? (b2 & 0x7F) + 0x80 : b2);
-        int i3 = ((b3 & 0x80) == 0x80 ? (b3 & 0x7F) + 0x80 : b3);
-
-        return ((long)i0 << 24) + ((long)i1 << 16) + ((long)i2 << 8) + (long)i3;
+    private long readUnsigned32(byte[] buffer, int offset) {
+        int i0 = buffer[offset++] & 0xFF;
+        int i1 = buffer[offset++] & 0xFF;
+        int i2 = buffer[offset++] & 0xFF;
+        int i3 = buffer[offset] & 0xFF;
+
+        int bits = (i0 << 24) | (i1 << 16) | (i2 << 8) | i3;
+        return bits & 0xFFFF_FFFFL;
     }
 
     /**
-     * Reads the NTP time stamp at the given offset in the buffer and returns
-     * it as a system time (milliseconds since January 1, 1970).
+     * Reads the NTP time stamp from the given offset in the buffer.
      */
-    private long readTimeStamp(byte[] buffer, int offset) {
-        long seconds = read32(buffer, offset);
-        long fraction = read32(buffer, offset + 4);
-        // Special case: zero means zero.
-        if (seconds == 0 && fraction == 0) {
-            return 0;
-        }
-        return ((seconds - OFFSET_1900_TO_1970) * 1000) + ((fraction * 1000L) / 0x100000000L);
+    private Timestamp64 readTimeStamp(byte[] buffer, int offset) {
+        long seconds = readUnsigned32(buffer, offset);
+        int fractionBits = (int) readUnsigned32(buffer, offset + 4);
+        return Timestamp64.fromComponents(seconds, fractionBits);
     }
 
     /**
-     * Writes system time (milliseconds since January 1, 1970) as an NTP time stamp
-     * at the given offset in the buffer.
+     * Writes the NTP time stamp at the given offset in the buffer.
      */
-    private void writeTimeStamp(byte[] buffer, int offset, long time) {
-        // Special case: zero means zero.
-        if (time == 0) {
-            Arrays.fill(buffer, offset, offset + 8, (byte) 0x00);
-            return;
-        }
-
-        long seconds = time / 1000L;
-        long milliseconds = time - seconds * 1000L;
-        seconds += OFFSET_1900_TO_1970;
-
+    private void writeTimeStamp(byte[] buffer, int offset, Timestamp64 timestamp) {
+        long seconds = timestamp.getEraSeconds();
         // write seconds in big endian format
-        buffer[offset++] = (byte)(seconds >> 24);
-        buffer[offset++] = (byte)(seconds >> 16);
-        buffer[offset++] = (byte)(seconds >> 8);
-        buffer[offset++] = (byte)(seconds >> 0);
+        buffer[offset++] = (byte) (seconds >>> 24);
+        buffer[offset++] = (byte) (seconds >>> 16);
+        buffer[offset++] = (byte) (seconds >>> 8);
+        buffer[offset++] = (byte) (seconds);
 
-        long fraction = milliseconds * 0x100000000L / 1000L;
+        int fractionBits = timestamp.getFractionBits();
         // write fraction in big endian format
-        buffer[offset++] = (byte)(fraction >> 24);
-        buffer[offset++] = (byte)(fraction >> 16);
-        buffer[offset++] = (byte)(fraction >> 8);
-        // low order bits should be random data
-        buffer[offset++] = (byte)(Math.random() * 255.0);
+        buffer[offset++] = (byte) (fractionBits >>> 24);
+        buffer[offset++] = (byte) (fractionBits >>> 16);
+        buffer[offset++] = (byte) (fractionBits >>> 8);
+        buffer[offset] = (byte) (fractionBits);
+    }
+
+    private static Random defaultRandom() {
+        Random random;
+        try {
+            random = SecureRandom.getInstanceStrong();
+        } catch (NoSuchAlgorithmException e) {
+            // This should never happen.
+            Slog.wtf(TAG, "Unable to access SecureRandom", e);
+            random = new Random(System.currentTimeMillis());
+        }
+        return random;
     }
 }
diff --git a/core/tests/coretests/src/android/net/SntpClientTest.java b/core/tests/coretests/src/android/net/SntpClientTest.java
index 178cd028dd4b..b400b9bf41dd 100644
--- a/core/tests/coretests/src/android/net/SntpClientTest.java
+++ b/core/tests/coretests/src/android/net/SntpClientTest.java
@@ -46,6 +46,7 @@ import java.time.Instant;
 import java.time.LocalDateTime;
 import java.time.ZoneOffset;
 import java.util.Arrays;
+import java.util.Random;
 import java.util.function.Supplier;
 
 @RunWith(AndroidJUnit4.class)
@@ -134,6 +135,7 @@ public class SntpClientTest {
     private SntpClient mClient;
     private Network mNetwork;
     private Supplier<Instant> mSystemTimeSupplier;
+    private Random mRandom;
 
     @SuppressWarnings("unchecked")
     @Before
@@ -143,9 +145,13 @@ public class SntpClientTest {
         // A mock network has NETID_UNSET, which allows the test to run, with a loopback server,
         // even w/o external networking.
         mNetwork = mock(Network.class, CALLS_REAL_METHODS);
+        mRandom = mock(Random.class);
 
         mSystemTimeSupplier = mock(Supplier.class);
-        mClient = new SntpClient(mSystemTimeSupplier);
+        // Returning zero means the "randomized" bottom bits of the clients transmit timestamp /
+        // server's originate timestamp will be zeros.
+        when(mRandom.nextInt()).thenReturn(0);
+        mClient = new SntpClient(mSystemTimeSupplier, mRandom);
     }
 
     /** Tests when the client and server are in ERA0. b/199481251. */
@@ -258,14 +264,14 @@ public class SntpClientTest {
             long simulatedClientElapsedTimeMillis = totalElapsedTimeMillis;
 
             // Create some symmetrical timestamps.
-            long clientRequestTimestamp =
-                    clientTime.minusMillis(simulatedClientElapsedTimeMillis / 2).toEpochMilli();
-            long clientResponseTimestamp =
-                    clientTime.plusMillis(simulatedClientElapsedTimeMillis / 2).toEpochMilli();
-            long serverReceiveTimestamp =
-                    serverTime.minusMillis(simulatedServerElapsedTimeMillis / 2).toEpochMilli();
-            long serverTransmitTimestamp =
-                    serverTime.plusMillis(simulatedServerElapsedTimeMillis / 2).toEpochMilli();
+            Timestamp64 clientRequestTimestamp = Timestamp64.fromInstant(
+                    clientTime.minusMillis(simulatedClientElapsedTimeMillis / 2));
+            Timestamp64 clientResponseTimestamp = Timestamp64.fromInstant(
+                    clientTime.plusMillis(simulatedClientElapsedTimeMillis / 2));
+            Timestamp64 serverReceiveTimestamp = Timestamp64.fromInstant(
+                    serverTime.minusMillis(simulatedServerElapsedTimeMillis / 2));
+            Timestamp64 serverTransmitTimestamp = Timestamp64.fromInstant(
+                    serverTime.plusMillis(simulatedServerElapsedTimeMillis / 2));
 
             Duration actualOffset = SntpClient.calculateClockOffset(
                     clientRequestTimestamp, serverReceiveTimestamp,
diff --git a/core/tests/coretests/src/android/net/sntp/PredictableRandom.java b/core/tests/coretests/src/android/net/sntp/PredictableRandom.java
new file mode 100644
index 000000000000..bb2922bf8ce2
--- /dev/null
+++ b/core/tests/coretests/src/android/net/sntp/PredictableRandom.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.sntp;
+
+import java.util.Random;
+
+class PredictableRandom extends Random {
+    private int[] mIntSequence = new int[] { 1 };
+    private int mIntPos = 0;
+
+    public void setIntSequence(int[] intSequence) {
+        this.mIntSequence = intSequence;
+    }
+
+    @Override
+    public int nextInt() {
+        int value = mIntSequence[mIntPos++];
+        mIntPos %= mIntSequence.length;
+        return value;
+    }
+}
diff --git a/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java b/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java
index 7e945e5f1cb6..c923812fa2fb 100644
--- a/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java
+++ b/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java
@@ -24,6 +24,9 @@ import static org.junit.Assert.fail;
 import org.junit.Test;
 
 import java.time.Instant;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
 
 public class Timestamp64Test {
 
@@ -205,6 +208,96 @@ public class Timestamp64Test {
                 actualNanos == expectedNanos || actualNanos == expectedNanos - 1);
     }
 
+    @Test
+    public void testMillisRandomizationConstant() {
+        // Mathematically, we can say that to represent 1000 different values, we need 10 binary
+        // digits (2^10 = 1024). The same is true whether we're dealing with integers or fractions.
+        // Unfortunately, for fractions those 1024 values do not correspond to discrete decimal
+        // values. Discrete millisecond values as fractions (e.g. 0.001 - 0.999) cannot be
+        // represented exactly except where the value can also be represented as some combination of
+        // powers of -2. When we convert back and forth, we truncate, so millisecond decimal
+        // fraction N represented as a binary fraction will always be equal to or lower than N. If
+        // we are truncating correctly it will never be as low as (N-0.001). N -> [N-0.001, N].
+
+        // We need to keep 10 bits to hold millis (inaccurately, since there are numbers that
+        // cannot be represented exactly), leaving us able to randomize the remaining 22 bits of the
+        // fraction part without significantly affecting the number represented.
+        assertEquals(22, Timestamp64.SUB_MILLIS_BITS_TO_RANDOMIZE);
+
+        // Brute force proof that randomization logic will keep the timestamp within the range
+        // [N-0.001, N] where x is in milliseconds.
+        int smallFractionRandomizedLow = 0;
+        int smallFractionRandomizedHigh = 0b00000000_00111111_11111111_11111111;
+        int largeFractionRandomizedLow = 0b11111111_11000000_00000000_00000000;
+        int largeFractionRandomizedHigh = 0b11111111_11111111_11111111_11111111;
+
+        long smallLowNanos = Timestamp64.fromComponents(
+                0, smallFractionRandomizedLow).toInstant(0).getNano();
+        long smallHighNanos = Timestamp64.fromComponents(
+                0, smallFractionRandomizedHigh).toInstant(0).getNano();
+        long smallDelta = smallHighNanos - smallLowNanos;
+        long millisInNanos = 1_000_000_000 / 1_000;
+        assertTrue(smallDelta >= 0 && smallDelta < millisInNanos);
+
+        long largeLowNanos = Timestamp64.fromComponents(
+                0, largeFractionRandomizedLow).toInstant(0).getNano();
+        long largeHighNanos = Timestamp64.fromComponents(
+                0, largeFractionRandomizedHigh).toInstant(0).getNano();
+        long largeDelta = largeHighNanos - largeLowNanos;
+        assertTrue(largeDelta >= 0 && largeDelta < millisInNanos);
+
+        PredictableRandom random = new PredictableRandom();
+        random.setIntSequence(new int[] { 0xFFFF_FFFF });
+        Timestamp64 zero = Timestamp64.fromComponents(0, 0);
+        Timestamp64 zeroWithFractionRandomized = zero.randomizeSubMillis(random);
+        assertEquals(zero.getEraSeconds(), zeroWithFractionRandomized.getEraSeconds());
+        assertEquals(smallFractionRandomizedHigh, zeroWithFractionRandomized.getFractionBits());
+    }
+
+    @Test
+    public void testRandomizeLowestBits() {
+        Random random = new Random(1);
+        {
+            int fractionBits = 0;
+            expectIllegalArgumentException(
+                    () -> Timestamp64.randomizeLowestBits(random, fractionBits, -1));
+            expectIllegalArgumentException(
+                    () -> Timestamp64.randomizeLowestBits(random, fractionBits, 0));
+            expectIllegalArgumentException(
+                    () -> Timestamp64.randomizeLowestBits(random, fractionBits, Integer.SIZE));
+            expectIllegalArgumentException(
+                    () -> Timestamp64.randomizeLowestBits(random, fractionBits, Integer.SIZE + 1));
+        }
+
+        // Check the behavior looks correct from a probabilistic point of view.
+        for (int input : new int[] { 0, 0xFFFFFFFF }) {
+            for (int bitCount = 1; bitCount < Integer.SIZE; bitCount++) {
+                int upperBitMask = 0xFFFFFFFF << bitCount;
+                int expectedUpperBits = input & upperBitMask;
+
+                Set<Integer> values = new HashSet<>();
+                values.add(input);
+
+                int trials = 100;
+                for (int i = 0; i < trials; i++) {
+                    int outputFractionBits =
+                            Timestamp64.randomizeLowestBits(random, input, bitCount);
+
+                    // Record the output value for later analysis.
+                    values.add(outputFractionBits);
+
+                    // Check upper bits did not change.
+                    assertEquals(expectedUpperBits, outputFractionBits & upperBitMask);
+                }
+
+                // It's possible to be more rigorous here, perhaps with a histogram. As bitCount
+                // rises, values.size() quickly trend towards the value of trials + 1. For now, this
+                // mostly just guards against a no-op implementation.
+                assertTrue(bitCount + ":" + values.size(), values.size() > 1);
+            }
+        }
+    }
+
     private static void expectIllegalArgumentException(Runnable r) {
         try {
             r.run();
-- 
GitLab