diff --git a/services/core/java/com/android/server/display/color/ColorDisplayService.java b/services/core/java/com/android/server/display/color/ColorDisplayService.java
index 03b0cfc3d8448d279ede4eb3ecd6da3dcc2af4b1..e3aa161f001ae95df4be914657b34bc996fdd039 100644
--- a/services/core/java/com/android/server/display/color/ColorDisplayService.java
+++ b/services/core/java/com/android/server/display/color/ColorDisplayService.java
@@ -234,7 +234,9 @@ public final class ColorDisplayService extends SystemService {
         }
     }
 
-    @VisibleForTesting void onUserChanged(int userHandle) {
+    // should be called in handler thread (same thread that started animation)
+    @VisibleForTesting
+    void onUserChanged(int userHandle) {
         final ContentResolver cr = getContext().getContentResolver();
 
         if (mCurrentUser != UserHandle.USER_NULL) {
@@ -473,6 +475,15 @@ public final class ColorDisplayService extends SystemService {
         }
     }
 
+    // should be called in handler thread (same thread that started animation)
+    @VisibleForTesting
+    void cancelAllAnimators() {
+        mNightDisplayTintController.cancelAnimator();
+        mGlobalSaturationTintController.cancelAnimator();
+        mReduceBrightColorsTintController.cancelAnimator();
+        mDisplayWhiteBalanceTintController.cancelAnimator();
+    }
+
     private boolean resetReduceBrightColors() {
         if (mCurrentUser == UserHandle.USER_NULL) {
             return false;
diff --git a/services/tests/displayservicetests/src/com/android/server/display/color/ColorDisplayServiceTest.java b/services/tests/displayservicetests/src/com/android/server/display/color/ColorDisplayServiceTest.java
index c7c09b5deb3591723c931c94c8759a01872bbd8c..ec27f9d220dc0724a6ac6094ac58fd107a5905ca 100644
--- a/services/tests/displayservicetests/src/com/android/server/display/color/ColorDisplayServiceTest.java
+++ b/services/tests/displayservicetests/src/com/android/server/display/color/ColorDisplayServiceTest.java
@@ -44,12 +44,12 @@ import android.provider.Settings.System;
 import android.test.mock.MockContentResolver;
 import android.view.Display;
 
-import androidx.test.InstrumentationRegistry;
+import androidx.test.platform.app.InstrumentationRegistry;
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.internal.R;
 import com.android.internal.util.test.FakeSettingsProvider;
-import com.android.server.LocalServices;
+import com.android.internal.util.test.LocalServiceKeeperRule;
 import com.android.server.SystemService;
 import com.android.server.twilight.TwilightListener;
 import com.android.server.twilight.TwilightManager;
@@ -57,6 +57,7 @@ import com.android.server.twilight.TwilightState;
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mockito;
@@ -90,9 +91,13 @@ public class ColorDisplayServiceTest {
         ColorDisplayManager.COLOR_MODE_BOOSTED,
     };
 
+    @Rule
+    public LocalServiceKeeperRule mLocalServiceKeeperRule = new LocalServiceKeeperRule();
+
     @Before
     public void setUp() {
-        mContext = Mockito.spy(new ContextWrapper(InstrumentationRegistry.getTargetContext()));
+        mContext = Mockito.spy(new ContextWrapper(
+                InstrumentationRegistry.getInstrumentation().getTargetContext()));
         doReturn(mContext).when(mContext).getApplicationContext();
 
         final Resources res = Mockito.spy(mContext.getResources());
@@ -112,43 +117,36 @@ public class ColorDisplayServiceTest {
         doReturn(am).when(mContext).getSystemService(Context.ALARM_SERVICE);
 
         mTwilightManager = new MockTwilightManager();
-        LocalServices.addService(TwilightManager.class, mTwilightManager);
+        mLocalServiceKeeperRule.overrideLocalService(TwilightManager.class, mTwilightManager);
 
         mDisplayTransformManager = Mockito.mock(DisplayTransformManager.class);
         doReturn(true).when(mDisplayTransformManager).needsLinearColorMatrix();
-        LocalServices.addService(DisplayTransformManager.class, mDisplayTransformManager);
+        mLocalServiceKeeperRule.overrideLocalService(
+                DisplayTransformManager.class, mDisplayTransformManager);
 
         mDisplayManagerInternal = Mockito.mock(DisplayManagerInternal.class);
-        LocalServices.removeServiceForTest(DisplayManagerInternal.class);
-        LocalServices.addService(DisplayManagerInternal.class, mDisplayManagerInternal);
+        mLocalServiceKeeperRule.overrideLocalService(
+                DisplayManagerInternal.class, mDisplayManagerInternal);
 
         mCds = new ColorDisplayService(mContext);
         mBinderService = mCds.new BinderService();
-        LocalServices.addService(ColorDisplayService.ColorDisplayServiceInternal.class,
+        mLocalServiceKeeperRule.overrideLocalService(
+                ColorDisplayService.ColorDisplayServiceInternal.class,
                 mCds.new ColorDisplayServiceInternal());
     }
 
     @After
     public void tearDown() {
-        /*
-         * Wait for internal {@link Handler} to finish processing pending messages, so that test
-         * code can safelyremove {@link DisplayTransformManager} mock from {@link LocalServices}.
-         */
-        mCds.mHandler.runWithScissors(() -> { /* nop */ }, /* timeout */ 1000);
+        // synchronously cancel all animations
+        mCds.mHandler.runWithScissors(() -> mCds.cancelAllAnimators(), /* timeout */ 1000);
         mCds = null;
 
-        LocalServices.removeServiceForTest(TwilightManager.class);
         mTwilightManager = null;
 
-        LocalServices.removeServiceForTest(DisplayTransformManager.class);
-
         mUserId = UserHandle.USER_NULL;
         mContext = null;
 
         FakeSettingsProvider.clearSettingsProvider();
-
-        LocalServices.removeServiceForTest(ColorDisplayService.ColorDisplayServiceInternal.class);
-        LocalServices.removeServiceForTest(DisplayManagerInternal.class);
     }
 
     @Test
@@ -1249,10 +1247,10 @@ public class ColorDisplayServiceTest {
     private void startService() {
         Secure.putIntForUser(mContext.getContentResolver(), Secure.USER_SETUP_COMPLETE, 1, mUserId);
 
-        InstrumentationRegistry.getInstrumentation().runOnMainSync(() -> {
-            mCds.onBootPhase(SystemService.PHASE_BOOT_COMPLETED);
-            mCds.onUserChanged(mUserId);
-        });
+        InstrumentationRegistry.getInstrumentation().runOnMainSync(
+                () -> mCds.onBootPhase(SystemService.PHASE_BOOT_COMPLETED));
+        // onUserChanged cancels running animations, and should be called in handler thread
+        mCds.mHandler.runWithScissors(() -> mCds.onUserChanged(mUserId), 1000);
     }
 
     /**