diff --git a/services/core/java/com/android/server/inputmethod/InputMethodManagerInternal.java b/services/core/java/com/android/server/inputmethod/InputMethodManagerInternal.java
index 4089a81dfc2009e9905f5ade64ce5ed79ec56ada..548945598a41fda9bc64656b656e5893f0414d1e 100644
--- a/services/core/java/com/android/server/inputmethod/InputMethodManagerInternal.java
+++ b/services/core/java/com/android/server/inputmethod/InputMethodManagerInternal.java
@@ -171,9 +171,11 @@ public abstract class InputMethodManagerInternal {
     public abstract void onImeParentChanged();
 
     /**
-     * Destroys the IME surface.
+     * Destroys the IME surface for the given display.
+     *
+     * @param displayId the display hosting the IME window
      */
-    public abstract void removeImeSurface();
+    public abstract void removeImeSurface(int displayId);
 
     /**
      * Updates the IME visibility, back disposition and show IME picker status for SystemUI.
@@ -302,7 +304,7 @@ public abstract class InputMethodManagerInternal {
                 }
 
                 @Override
-                public void removeImeSurface() {
+                public void removeImeSurface(int displayId) {
                 }
 
                 @Override
diff --git a/services/core/java/com/android/server/inputmethod/InputMethodManagerService.java b/services/core/java/com/android/server/inputmethod/InputMethodManagerService.java
index 16e043cfb64d54644d92be4c901728ac9890b6a0..c3eb5195214100f85859648cef6171cd6bccf39d 100644
--- a/services/core/java/com/android/server/inputmethod/InputMethodManagerService.java
+++ b/services/core/java/com/android/server/inputmethod/InputMethodManagerService.java
@@ -5683,7 +5683,7 @@ public final class InputMethodManagerService extends IInputMethodManager.Stub
         }
 
         @Override
-        public void removeImeSurface() {
+        public void removeImeSurface(int displayId) {
             mHandler.obtainMessage(MSG_REMOVE_IME_SURFACE).sendToTarget();
         }
 
diff --git a/services/core/java/com/android/server/wm/DisplayContent.java b/services/core/java/com/android/server/wm/DisplayContent.java
index ae10ce3690aacce50b27482edc4072876214c78c..ba22763dd289ae7d9d82ee7c726e2919981ab2d3 100644
--- a/services/core/java/com/android/server/wm/DisplayContent.java
+++ b/services/core/java/com/android/server/wm/DisplayContent.java
@@ -7090,7 +7090,7 @@ class DisplayContent extends RootDisplayArea implements WindowManagerPolicy.Disp
         }
 
         @Override
-        public void notifyInsetsControlChanged() {
+        public void notifyInsetsControlChanged(int displayId) {
             final InsetsStateController stateController = getInsetsStateController();
             try {
                 mRemoteInsetsController.insetsControlChanged(stateController.getRawInsetsState(),
diff --git a/services/core/java/com/android/server/wm/InsetsControlTarget.java b/services/core/java/com/android/server/wm/InsetsControlTarget.java
index 8ecbc177896cf269d39e9c1237ca756655992011..b74eb56ebdca7a5912d37023d08e07b9f1534803 100644
--- a/services/core/java/com/android/server/wm/InsetsControlTarget.java
+++ b/services/core/java/com/android/server/wm/InsetsControlTarget.java
@@ -29,8 +29,10 @@ interface InsetsControlTarget {
 
     /**
      * Notifies the control target that the insets control has changed.
+     *
+     * @param displayId the display hosting the window of this target
      */
-    default void notifyInsetsControlChanged() {
+    default void notifyInsetsControlChanged(int displayId) {
     };
 
     /**
diff --git a/services/core/java/com/android/server/wm/InsetsPolicy.java b/services/core/java/com/android/server/wm/InsetsPolicy.java
index 7815679902352b58b32a7763b30642281a429ebb..3c556bf7b12658c10e05075a0c0523d10604d966 100644
--- a/services/core/java/com/android/server/wm/InsetsPolicy.java
+++ b/services/core/java/com/android/server/wm/InsetsPolicy.java
@@ -728,7 +728,7 @@ class InsetsPolicy {
         }
 
         @Override
-        public void notifyInsetsControlChanged() {
+        public void notifyInsetsControlChanged(int displayId) {
             mHandler.post(this);
         }
 
diff --git a/services/core/java/com/android/server/wm/InsetsStateController.java b/services/core/java/com/android/server/wm/InsetsStateController.java
index c4d01291f558c1b02a4cc94cb3eaae607c12379c..6b9fcf411ce18477b71d898566630d5eca005d24 100644
--- a/services/core/java/com/android/server/wm/InsetsStateController.java
+++ b/services/core/java/com/android/server/wm/InsetsStateController.java
@@ -72,7 +72,7 @@ class InsetsStateController {
     };
     private final InsetsControlTarget mEmptyImeControlTarget = new InsetsControlTarget() {
         @Override
-        public void notifyInsetsControlChanged() {
+        public void notifyInsetsControlChanged(int displayId) {
             InsetsSourceControl[] controls = getControlsForDispatch(this);
             if (controls == null) {
                 return;
@@ -80,7 +80,7 @@ class InsetsStateController {
             for (InsetsSourceControl control : controls) {
                 if (control.getType() == WindowInsets.Type.ime()) {
                     mDisplayContent.mWmService.mH.post(() ->
-                            InputMethodManagerInternal.get().removeImeSurface());
+                            InputMethodManagerInternal.get().removeImeSurface(displayId));
                 }
             }
         }
@@ -370,9 +370,10 @@ class InsetsStateController {
                 provider.onSurfaceTransactionApplied();
             }
             final ArraySet<InsetsControlTarget> newControlTargets = new ArraySet<>();
+            int displayId = mDisplayContent.getDisplayId();
             for (int i = mPendingControlChanged.size() - 1; i >= 0; i--) {
                 final InsetsControlTarget controlTarget = mPendingControlChanged.valueAt(i);
-                controlTarget.notifyInsetsControlChanged();
+                controlTarget.notifyInsetsControlChanged(displayId);
                 if (mControlTargetProvidersMap.containsKey(controlTarget)) {
                     // We only collect targets who get controls, not lose controls.
                     newControlTargets.add(controlTarget);
diff --git a/services/core/java/com/android/server/wm/WindowState.java b/services/core/java/com/android/server/wm/WindowState.java
index f5f0dc6d7178c9a5df987d5a013e413fddc4f4f0..56e7c69fe52918034715e3360fafbf654524fb48 100644
--- a/services/core/java/com/android/server/wm/WindowState.java
+++ b/services/core/java/com/android/server/wm/WindowState.java
@@ -3775,7 +3775,7 @@ class WindowState extends WindowContainer<WindowState> implements WindowManagerP
     }
 
     @Override
-    public void notifyInsetsControlChanged() {
+    public void notifyInsetsControlChanged(int displayId) {
         ProtoLog.d(WM_DEBUG_WINDOW_INSETS, "notifyInsetsControlChanged for %s ", this);
         if (mRemoved) {
             return;
diff --git a/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java b/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java
index df4af112c0872e41b13232359935848d47532f77..616a23e7ab5b6be9f796769428399a261e280dda 100644
--- a/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java
+++ b/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java
@@ -294,7 +294,7 @@ class WindowTestsBase extends SystemServiceTestsBase {
      */
     static void suppressInsetsAnimation(InsetsControlTarget target) {
         spyOn(target);
-        Mockito.doNothing().when(target).notifyInsetsControlChanged();
+        Mockito.doNothing().when(target).notifyInsetsControlChanged(anyInt());
     }
 
     @After