diff --git a/core/java/android/view/ViewRootImpl.java b/core/java/android/view/ViewRootImpl.java
index a37c453d33baf4e202d14bbcae4d7e891fdb7dae..bdada11c70739cc9f83258cebc95a9dd3434997a 100644
--- a/core/java/android/view/ViewRootImpl.java
+++ b/core/java/android/view/ViewRootImpl.java
@@ -7344,8 +7344,6 @@ public final class ViewRootImpl implements ViewParent,
             final KeyEvent event = (KeyEvent)q.mEvent;
             if (mView.dispatchKeyEventPreIme(event)) {
                 return FINISH_HANDLED;
-            } else if (q.forPreImeOnly()) {
-                return FINISH_NOT_HANDLED;
             }
             return FORWARD;
         }
@@ -9852,7 +9850,6 @@ public final class ViewRootImpl implements ViewParent,
         public static final int FLAG_RESYNTHESIZED = 1 << 4;
         public static final int FLAG_UNHANDLED = 1 << 5;
         public static final int FLAG_MODIFIED_FOR_COMPATIBILITY = 1 << 6;
-        public static final int FLAG_PRE_IME_ONLY = 1 << 7;
 
         public QueuedInputEvent mNext;
 
@@ -9860,13 +9857,6 @@ public final class ViewRootImpl implements ViewParent,
         public InputEventReceiver mReceiver;
         public int mFlags;
 
-        public boolean forPreImeOnly() {
-            if ((mFlags & FLAG_PRE_IME_ONLY) != 0) {
-                return true;
-            }
-            return false;
-        }
-
         public boolean shouldSkipIme() {
             if ((mFlags & FLAG_DELIVER_POST_IME) != 0) {
                 return true;
@@ -9893,7 +9883,6 @@ public final class ViewRootImpl implements ViewParent,
             hasPrevious = flagToString("FINISHED_HANDLED", FLAG_FINISHED_HANDLED, hasPrevious, sb);
             hasPrevious = flagToString("RESYNTHESIZED", FLAG_RESYNTHESIZED, hasPrevious, sb);
             hasPrevious = flagToString("UNHANDLED", FLAG_UNHANDLED, hasPrevious, sb);
-            hasPrevious = flagToString("FLAG_PRE_IME_ONLY", FLAG_PRE_IME_ONLY, hasPrevious, sb);
             if (!hasPrevious) {
                 sb.append("0");
             }
@@ -9950,7 +9939,7 @@ public final class ViewRootImpl implements ViewParent,
     }
 
     @UnsupportedAppUsage
-    QueuedInputEvent enqueueInputEvent(InputEvent event,
+    void enqueueInputEvent(InputEvent event,
             InputEventReceiver receiver, int flags, boolean processImmediately) {
         QueuedInputEvent q = obtainQueuedInputEvent(event, receiver, flags);
 
@@ -9989,7 +9978,6 @@ public final class ViewRootImpl implements ViewParent,
         } else {
             scheduleProcessInputEvents();
         }
-        return q;
     }
 
     private void scheduleProcessInputEvents() {
@@ -12286,45 +12274,29 @@ public final class ViewRootImpl implements ViewParent,
                             + "IWindow:%s Session:%s",
                     mOnBackInvokedDispatcher, mBasePackageName, mWindow, mWindowSession));
         }
-        mOnBackInvokedDispatcher.attachToWindow(mWindowSession, mWindow, this,
+        mOnBackInvokedDispatcher.attachToWindow(mWindowSession, mWindow,
                 mImeBackAnimationController);
     }
 
-    /**
-     * Sends {@link KeyEvent#ACTION_DOWN ACTION_DOWN} and {@link KeyEvent#ACTION_UP ACTION_UP}
-     * back key events
-     *
-     * @param preImeOnly whether the back events should be sent to the pre-ime stage only
-     * @return whether the event was handled (i.e. onKeyPreIme consumed it if preImeOnly=true)
-     */
-    public boolean injectBackKeyEvents(boolean preImeOnly) {
-        boolean consumed;
-        try {
-            processingBackKey(true);
-            sendBackKeyEvent(KeyEvent.ACTION_DOWN, preImeOnly);
-            consumed = sendBackKeyEvent(KeyEvent.ACTION_UP, preImeOnly);
-        } finally {
-            processingBackKey(false);
-        }
-        return consumed;
-    }
-
-    private boolean sendBackKeyEvent(int action, boolean preImeOnly) {
+    private void sendBackKeyEvent(int action) {
         long when = SystemClock.uptimeMillis();
         final KeyEvent ev = new KeyEvent(when, when, action,
                 KeyEvent.KEYCODE_BACK, 0 /* repeat */, 0 /* metaState */,
                 KeyCharacterMap.VIRTUAL_KEYBOARD, 0 /* scancode */,
                 KeyEvent.FLAG_FROM_SYSTEM | KeyEvent.FLAG_VIRTUAL_HARD_KEY,
                 InputDevice.SOURCE_KEYBOARD);
-        int flags = preImeOnly ? QueuedInputEvent.FLAG_PRE_IME_ONLY : 0;
-        QueuedInputEvent q = enqueueInputEvent(ev, null /* receiver */, flags,
-                true /* processImmediately */);
-        return (q.mFlags & QueuedInputEvent.FLAG_FINISHED_HANDLED) != 0;
+        enqueueInputEvent(ev, null /* receiver */, 0 /* flags */, true /* processImmediately */);
     }
 
     private void registerCompatOnBackInvokedCallback() {
         mCompatOnBackInvokedCallback = () -> {
-            injectBackKeyEvents(/* preImeOnly */ false);
+            try {
+                processingBackKey(true);
+                sendBackKeyEvent(KeyEvent.ACTION_DOWN);
+                sendBackKeyEvent(KeyEvent.ACTION_UP);
+            } finally {
+                processingBackKey(false);
+            }
         };
         if (mOnBackInvokedDispatcher.hasImeOnBackInvokedDispatcher()) {
             Log.d(TAG, "Skip registering CompatOnBackInvokedCallback on IME dispatcher");
diff --git a/core/java/android/window/WindowOnBackInvokedDispatcher.java b/core/java/android/window/WindowOnBackInvokedDispatcher.java
index 0fb5e34821780c4e47c55c79cc6bbfd8df1b355b..0ff52f13222df28cc4433bcf9e33bd3360fec161 100644
--- a/core/java/android/window/WindowOnBackInvokedDispatcher.java
+++ b/core/java/android/window/WindowOnBackInvokedDispatcher.java
@@ -37,7 +37,6 @@ import android.view.IWindow;
 import android.view.IWindowSession;
 import android.view.ImeBackAnimationController;
 import android.view.MotionEvent;
-import android.view.ViewRootImpl;
 
 import androidx.annotation.VisibleForTesting;
 
@@ -69,7 +68,6 @@ import java.util.function.Supplier;
 public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
     private IWindowSession mWindowSession;
     private IWindow mWindow;
-    private ViewRootImpl mViewRoot;
     @VisibleForTesting
     public final BackTouchTracker mTouchTracker = new BackTouchTracker();
     @VisibleForTesting
@@ -136,12 +134,10 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
      * is attached a window.
      */
     public void attachToWindow(@NonNull IWindowSession windowSession, @NonNull IWindow window,
-            @Nullable ViewRootImpl viewRoot,
             @Nullable ImeBackAnimationController imeBackAnimationController) {
         synchronized (mLock) {
             mWindowSession = windowSession;
             mWindow = window;
-            mViewRoot = viewRoot;
             mImeBackAnimationController = imeBackAnimationController;
             if (!mAllCallbacks.isEmpty()) {
                 setTopOnBackInvokedCallback(getTopCallback());
@@ -155,7 +151,6 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
             clear();
             mWindow = null;
             mWindowSession = null;
-            mViewRoot = null;
             mImeBackAnimationController = null;
         }
     }
@@ -181,6 +176,8 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
                 return;
             }
             if (callback instanceof ImeOnBackInvokedDispatcher.ImeOnBackInvokedCallback) {
+                // Fall back to compat back key injection if legacy back behaviour should be used.
+                if (!isOnBackInvokedCallbackEnabled()) return;
                 if (callback instanceof ImeOnBackInvokedDispatcher.DefaultImeOnBackAnimationCallback
                         && mImeBackAnimationController != null) {
                     // register ImeBackAnimationController instead to play predictive back animation
@@ -312,7 +309,7 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
             if (callback != null) {
                 int priority = mAllCallbacks.get(callback);
                 final IOnBackInvokedCallback iCallback = new OnBackInvokedCallbackWrapper(
-                        callback, mTouchTracker, mProgressAnimator, mHandler, mViewRoot);
+                        callback, mTouchTracker, mProgressAnimator, mHandler);
                 callbackInfo = new OnBackInvokedCallbackInfo(
                         iCallback,
                         priority,
@@ -402,20 +399,16 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
         private final BackTouchTracker mTouchTracker;
         @NonNull
         private final Handler mHandler;
-        @Nullable
-        private ViewRootImpl mViewRoot;
 
         OnBackInvokedCallbackWrapper(
                 @NonNull OnBackInvokedCallback callback,
                 @NonNull BackTouchTracker touchTracker,
                 @NonNull BackProgressAnimator progressAnimator,
-                @NonNull Handler handler,
-                @Nullable ViewRootImpl viewRoot) {
+                @NonNull Handler handler) {
             mCallback = new WeakReference<>(callback);
             mTouchTracker = touchTracker;
             mProgressAnimator = progressAnimator;
             mHandler = handler;
-            mViewRoot = viewRoot;
         }
 
         @Override
@@ -458,7 +451,6 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
         public void onBackInvoked() throws RemoteException {
             mHandler.post(() -> {
                 mTouchTracker.reset();
-                if (consumedByOnKeyPreIme()) return;
                 boolean isInProgress = mProgressAnimator.isBackAnimationInProgress();
                 mProgressAnimator.reset();
                 // TODO(b/333957271): Re-introduce auto fling progress generation.
@@ -475,26 +467,6 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
             });
         }
 
-        private boolean consumedByOnKeyPreIme() {
-            final OnBackInvokedCallback callback = mCallback.get();
-            if ((callback instanceof ImeBackAnimationController
-                    || callback instanceof ImeOnBackInvokedDispatcher.ImeOnBackInvokedCallback)
-                    && mViewRoot != null && !isOnBackInvokedCallbackEnabled(mViewRoot.mContext)) {
-                // call onKeyPreIme API if the current callback is an IME callback and the app has
-                // not set enableOnBackInvokedCallback="false"
-                boolean consumed = mViewRoot.injectBackKeyEvents(/*preImeOnly*/ true);
-                if (consumed) {
-                    // back event intercepted by app in onKeyPreIme -> cancel the IME animation.
-                    final OnBackAnimationCallback animationCallback = getBackAnimationCallback();
-                    if (animationCallback != null) {
-                        mProgressAnimator.onBackCancelled(animationCallback::onBackCancelled);
-                    }
-                    return true;
-                }
-            }
-            return false;
-        }
-
         @Override
         public void setTriggerBack(boolean triggerBack) throws RemoteException {
             mTouchTracker.setTriggerBack(triggerBack);
diff --git a/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java b/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
index 1aada40ab8e9d042ee5ade174f7cbad426ea3200..50d7f59f70e95c06fd3bac75a513d524a91be024 100644
--- a/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
+++ b/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
@@ -111,7 +111,7 @@ public class WindowOnBackInvokedDispatcherTest {
         doReturn(mApplicationInfo).when(mContext).getApplicationInfo();
 
         mDispatcher = new WindowOnBackInvokedDispatcher(mContext, Looper.getMainLooper());
-        mDispatcher.attachToWindow(mWindowSession, mWindow, null, mImeBackAnimationController);
+        mDispatcher.attachToWindow(mWindowSession, mWindow, mImeBackAnimationController);
     }
 
     private void waitForIdle() {
@@ -454,26 +454,25 @@ public class WindowOnBackInvokedDispatcherTest {
 
     @Test
     public void registerImeCallbacks_onBackInvokedCallbackEnabled() throws RemoteException {
-        verifyImeCallackRegistrations();
-    }
-
-    @Test
-    public void registerImeCallbacks_onBackInvokedCallbackDisabled() throws RemoteException {
-        doReturn(false).when(mApplicationInfo).isOnBackInvokedCallbackEnabled();
-        verifyImeCallackRegistrations();
-    }
-
-    private void verifyImeCallackRegistrations() throws RemoteException {
-        // verify default callback is replaced with ImeBackAnimationController
-        mDispatcher.registerOnBackInvokedCallbackUnchecked(mDefaultImeCallback, PRIORITY_DEFAULT);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mDefaultImeCallback);
         assertCallbacksSize(/* default */ 1, /* overlay */ 0);
         assertSetCallbackInfo();
         assertTopCallback(mImeBackAnimationController);
 
-        // verify regular ime callback is successfully registered
-        mDispatcher.registerOnBackInvokedCallbackUnchecked(mImeCallback, PRIORITY_DEFAULT);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mImeCallback);
         assertCallbacksSize(/* default */ 2, /* overlay */ 0);
         assertSetCallbackInfo();
         assertTopCallback(mImeCallback);
     }
+
+    @Test
+    public void registerImeCallbacks_legacyBack() throws RemoteException {
+        doReturn(false).when(mApplicationInfo).isOnBackInvokedCallbackEnabled();
+
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mDefaultImeCallback);
+        assertNoSetCallbackInfo();
+
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mImeCallback);
+        assertNoSetCallbackInfo();
+    }
 }
diff --git a/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java b/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
index c67d1ec63827924f652daddb271161b52e3679cf..a39a1a8637dfa9dddeb104fd0a51888ac3ce3ff8 100644
--- a/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
+++ b/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
@@ -550,7 +550,7 @@ public class BackNavigationControllerTests extends WindowTestsBase {
         }).when(appWindow.mSession).setOnBackInvokedCallbackInfo(eq(appWindow.mClient), any());
 
         addToWindowMap(appWindow, true);
-        dispatcher.attachToWindow(appWindow.mSession, appWindow.mClient, null, null);
+        dispatcher.attachToWindow(appWindow.mSession, appWindow.mClient, null);
 
 
         OnBackInvokedCallback appCallback = createBackCallback(appLatch);