diff --git a/core/java/android/view/ViewRootImpl.java b/core/java/android/view/ViewRootImpl.java
index bdada11c70739cc9f83258cebc95a9dd3434997a..a37c453d33baf4e202d14bbcae4d7e891fdb7dae 100644
--- a/core/java/android/view/ViewRootImpl.java
+++ b/core/java/android/view/ViewRootImpl.java
@@ -7344,6 +7344,8 @@ public final class ViewRootImpl implements ViewParent,
             final KeyEvent event = (KeyEvent)q.mEvent;
             if (mView.dispatchKeyEventPreIme(event)) {
                 return FINISH_HANDLED;
+            } else if (q.forPreImeOnly()) {
+                return FINISH_NOT_HANDLED;
             }
             return FORWARD;
         }
@@ -9850,6 +9852,7 @@ public final class ViewRootImpl implements ViewParent,
         public static final int FLAG_RESYNTHESIZED = 1 << 4;
         public static final int FLAG_UNHANDLED = 1 << 5;
         public static final int FLAG_MODIFIED_FOR_COMPATIBILITY = 1 << 6;
+        public static final int FLAG_PRE_IME_ONLY = 1 << 7;
 
         public QueuedInputEvent mNext;
 
@@ -9857,6 +9860,13 @@ public final class ViewRootImpl implements ViewParent,
         public InputEventReceiver mReceiver;
         public int mFlags;
 
+        public boolean forPreImeOnly() {
+            if ((mFlags & FLAG_PRE_IME_ONLY) != 0) {
+                return true;
+            }
+            return false;
+        }
+
         public boolean shouldSkipIme() {
             if ((mFlags & FLAG_DELIVER_POST_IME) != 0) {
                 return true;
@@ -9883,6 +9893,7 @@ public final class ViewRootImpl implements ViewParent,
             hasPrevious = flagToString("FINISHED_HANDLED", FLAG_FINISHED_HANDLED, hasPrevious, sb);
             hasPrevious = flagToString("RESYNTHESIZED", FLAG_RESYNTHESIZED, hasPrevious, sb);
             hasPrevious = flagToString("UNHANDLED", FLAG_UNHANDLED, hasPrevious, sb);
+            hasPrevious = flagToString("FLAG_PRE_IME_ONLY", FLAG_PRE_IME_ONLY, hasPrevious, sb);
             if (!hasPrevious) {
                 sb.append("0");
             }
@@ -9939,7 +9950,7 @@ public final class ViewRootImpl implements ViewParent,
     }
 
     @UnsupportedAppUsage
-    void enqueueInputEvent(InputEvent event,
+    QueuedInputEvent enqueueInputEvent(InputEvent event,
             InputEventReceiver receiver, int flags, boolean processImmediately) {
         QueuedInputEvent q = obtainQueuedInputEvent(event, receiver, flags);
 
@@ -9978,6 +9989,7 @@ public final class ViewRootImpl implements ViewParent,
         } else {
             scheduleProcessInputEvents();
         }
+        return q;
     }
 
     private void scheduleProcessInputEvents() {
@@ -12274,29 +12286,45 @@ public final class ViewRootImpl implements ViewParent,
                             + "IWindow:%s Session:%s",
                     mOnBackInvokedDispatcher, mBasePackageName, mWindow, mWindowSession));
         }
-        mOnBackInvokedDispatcher.attachToWindow(mWindowSession, mWindow,
+        mOnBackInvokedDispatcher.attachToWindow(mWindowSession, mWindow, this,
                 mImeBackAnimationController);
     }
 
-    private void sendBackKeyEvent(int action) {
+    /**
+     * Sends {@link KeyEvent#ACTION_DOWN ACTION_DOWN} and {@link KeyEvent#ACTION_UP ACTION_UP}
+     * back key events
+     *
+     * @param preImeOnly whether the back events should be sent to the pre-ime stage only
+     * @return whether the event was handled (i.e. onKeyPreIme consumed it if preImeOnly=true)
+     */
+    public boolean injectBackKeyEvents(boolean preImeOnly) {
+        boolean consumed;
+        try {
+            processingBackKey(true);
+            sendBackKeyEvent(KeyEvent.ACTION_DOWN, preImeOnly);
+            consumed = sendBackKeyEvent(KeyEvent.ACTION_UP, preImeOnly);
+        } finally {
+            processingBackKey(false);
+        }
+        return consumed;
+    }
+
+    private boolean sendBackKeyEvent(int action, boolean preImeOnly) {
         long when = SystemClock.uptimeMillis();
         final KeyEvent ev = new KeyEvent(when, when, action,
                 KeyEvent.KEYCODE_BACK, 0 /* repeat */, 0 /* metaState */,
                 KeyCharacterMap.VIRTUAL_KEYBOARD, 0 /* scancode */,
                 KeyEvent.FLAG_FROM_SYSTEM | KeyEvent.FLAG_VIRTUAL_HARD_KEY,
                 InputDevice.SOURCE_KEYBOARD);
-        enqueueInputEvent(ev, null /* receiver */, 0 /* flags */, true /* processImmediately */);
+        int flags = preImeOnly ? QueuedInputEvent.FLAG_PRE_IME_ONLY : 0;
+        QueuedInputEvent q = enqueueInputEvent(ev, null /* receiver */, flags,
+                true /* processImmediately */);
+        return (q.mFlags & QueuedInputEvent.FLAG_FINISHED_HANDLED) != 0;
     }
 
     private void registerCompatOnBackInvokedCallback() {
         mCompatOnBackInvokedCallback = () -> {
-            try {
-                processingBackKey(true);
-                sendBackKeyEvent(KeyEvent.ACTION_DOWN);
-                sendBackKeyEvent(KeyEvent.ACTION_UP);
-            } finally {
-                processingBackKey(false);
-            }
+            injectBackKeyEvents(/* preImeOnly */ false);
         };
         if (mOnBackInvokedDispatcher.hasImeOnBackInvokedDispatcher()) {
             Log.d(TAG, "Skip registering CompatOnBackInvokedCallback on IME dispatcher");
diff --git a/core/java/android/window/WindowOnBackInvokedDispatcher.java b/core/java/android/window/WindowOnBackInvokedDispatcher.java
index 0ff52f13222df28cc4433bcf9e33bd3360fec161..0fb5e34821780c4e47c55c79cc6bbfd8df1b355b 100644
--- a/core/java/android/window/WindowOnBackInvokedDispatcher.java
+++ b/core/java/android/window/WindowOnBackInvokedDispatcher.java
@@ -37,6 +37,7 @@ import android.view.IWindow;
 import android.view.IWindowSession;
 import android.view.ImeBackAnimationController;
 import android.view.MotionEvent;
+import android.view.ViewRootImpl;
 
 import androidx.annotation.VisibleForTesting;
 
@@ -68,6 +69,7 @@ import java.util.function.Supplier;
 public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
     private IWindowSession mWindowSession;
     private IWindow mWindow;
+    private ViewRootImpl mViewRoot;
     @VisibleForTesting
     public final BackTouchTracker mTouchTracker = new BackTouchTracker();
     @VisibleForTesting
@@ -134,10 +136,12 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
      * is attached a window.
      */
     public void attachToWindow(@NonNull IWindowSession windowSession, @NonNull IWindow window,
+            @Nullable ViewRootImpl viewRoot,
             @Nullable ImeBackAnimationController imeBackAnimationController) {
         synchronized (mLock) {
             mWindowSession = windowSession;
             mWindow = window;
+            mViewRoot = viewRoot;
             mImeBackAnimationController = imeBackAnimationController;
             if (!mAllCallbacks.isEmpty()) {
                 setTopOnBackInvokedCallback(getTopCallback());
@@ -151,6 +155,7 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
             clear();
             mWindow = null;
             mWindowSession = null;
+            mViewRoot = null;
             mImeBackAnimationController = null;
         }
     }
@@ -176,8 +181,6 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
                 return;
             }
             if (callback instanceof ImeOnBackInvokedDispatcher.ImeOnBackInvokedCallback) {
-                // Fall back to compat back key injection if legacy back behaviour should be used.
-                if (!isOnBackInvokedCallbackEnabled()) return;
                 if (callback instanceof ImeOnBackInvokedDispatcher.DefaultImeOnBackAnimationCallback
                         && mImeBackAnimationController != null) {
                     // register ImeBackAnimationController instead to play predictive back animation
@@ -309,7 +312,7 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
             if (callback != null) {
                 int priority = mAllCallbacks.get(callback);
                 final IOnBackInvokedCallback iCallback = new OnBackInvokedCallbackWrapper(
-                        callback, mTouchTracker, mProgressAnimator, mHandler);
+                        callback, mTouchTracker, mProgressAnimator, mHandler, mViewRoot);
                 callbackInfo = new OnBackInvokedCallbackInfo(
                         iCallback,
                         priority,
@@ -399,16 +402,20 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
         private final BackTouchTracker mTouchTracker;
         @NonNull
         private final Handler mHandler;
+        @Nullable
+        private ViewRootImpl mViewRoot;
 
         OnBackInvokedCallbackWrapper(
                 @NonNull OnBackInvokedCallback callback,
                 @NonNull BackTouchTracker touchTracker,
                 @NonNull BackProgressAnimator progressAnimator,
-                @NonNull Handler handler) {
+                @NonNull Handler handler,
+                @Nullable ViewRootImpl viewRoot) {
             mCallback = new WeakReference<>(callback);
             mTouchTracker = touchTracker;
             mProgressAnimator = progressAnimator;
             mHandler = handler;
+            mViewRoot = viewRoot;
         }
 
         @Override
@@ -451,6 +458,7 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
         public void onBackInvoked() throws RemoteException {
             mHandler.post(() -> {
                 mTouchTracker.reset();
+                if (consumedByOnKeyPreIme()) return;
                 boolean isInProgress = mProgressAnimator.isBackAnimationInProgress();
                 mProgressAnimator.reset();
                 // TODO(b/333957271): Re-introduce auto fling progress generation.
@@ -467,6 +475,26 @@ public class WindowOnBackInvokedDispatcher implements OnBackInvokedDispatcher {
             });
         }
 
+        private boolean consumedByOnKeyPreIme() {
+            final OnBackInvokedCallback callback = mCallback.get();
+            if ((callback instanceof ImeBackAnimationController
+                    || callback instanceof ImeOnBackInvokedDispatcher.ImeOnBackInvokedCallback)
+                    && mViewRoot != null && !isOnBackInvokedCallbackEnabled(mViewRoot.mContext)) {
+                // call onKeyPreIme API if the current callback is an IME callback and the app has
+                // not set enableOnBackInvokedCallback="false"
+                boolean consumed = mViewRoot.injectBackKeyEvents(/*preImeOnly*/ true);
+                if (consumed) {
+                    // back event intercepted by app in onKeyPreIme -> cancel the IME animation.
+                    final OnBackAnimationCallback animationCallback = getBackAnimationCallback();
+                    if (animationCallback != null) {
+                        mProgressAnimator.onBackCancelled(animationCallback::onBackCancelled);
+                    }
+                    return true;
+                }
+            }
+            return false;
+        }
+
         @Override
         public void setTriggerBack(boolean triggerBack) throws RemoteException {
             mTouchTracker.setTriggerBack(triggerBack);
diff --git a/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java b/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
index 50d7f59f70e95c06fd3bac75a513d524a91be024..1aada40ab8e9d042ee5ade174f7cbad426ea3200 100644
--- a/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
+++ b/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
@@ -111,7 +111,7 @@ public class WindowOnBackInvokedDispatcherTest {
         doReturn(mApplicationInfo).when(mContext).getApplicationInfo();
 
         mDispatcher = new WindowOnBackInvokedDispatcher(mContext, Looper.getMainLooper());
-        mDispatcher.attachToWindow(mWindowSession, mWindow, mImeBackAnimationController);
+        mDispatcher.attachToWindow(mWindowSession, mWindow, null, mImeBackAnimationController);
     }
 
     private void waitForIdle() {
@@ -454,25 +454,26 @@ public class WindowOnBackInvokedDispatcherTest {
 
     @Test
     public void registerImeCallbacks_onBackInvokedCallbackEnabled() throws RemoteException {
-        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mDefaultImeCallback);
+        verifyImeCallackRegistrations();
+    }
+
+    @Test
+    public void registerImeCallbacks_onBackInvokedCallbackDisabled() throws RemoteException {
+        doReturn(false).when(mApplicationInfo).isOnBackInvokedCallbackEnabled();
+        verifyImeCallackRegistrations();
+    }
+
+    private void verifyImeCallackRegistrations() throws RemoteException {
+        // verify default callback is replaced with ImeBackAnimationController
+        mDispatcher.registerOnBackInvokedCallbackUnchecked(mDefaultImeCallback, PRIORITY_DEFAULT);
         assertCallbacksSize(/* default */ 1, /* overlay */ 0);
         assertSetCallbackInfo();
         assertTopCallback(mImeBackAnimationController);
 
-        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mImeCallback);
+        // verify regular ime callback is successfully registered
+        mDispatcher.registerOnBackInvokedCallbackUnchecked(mImeCallback, PRIORITY_DEFAULT);
         assertCallbacksSize(/* default */ 2, /* overlay */ 0);
         assertSetCallbackInfo();
         assertTopCallback(mImeCallback);
     }
-
-    @Test
-    public void registerImeCallbacks_legacyBack() throws RemoteException {
-        doReturn(false).when(mApplicationInfo).isOnBackInvokedCallbackEnabled();
-
-        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mDefaultImeCallback);
-        assertNoSetCallbackInfo();
-
-        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mImeCallback);
-        assertNoSetCallbackInfo();
-    }
 }
diff --git a/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java b/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
index a39a1a8637dfa9dddeb104fd0a51888ac3ce3ff8..c67d1ec63827924f652daddb271161b52e3679cf 100644
--- a/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
+++ b/services/tests/wmtests/src/com/android/server/wm/BackNavigationControllerTests.java
@@ -550,7 +550,7 @@ public class BackNavigationControllerTests extends WindowTestsBase {
         }).when(appWindow.mSession).setOnBackInvokedCallbackInfo(eq(appWindow.mClient), any());
 
         addToWindowMap(appWindow, true);
-        dispatcher.attachToWindow(appWindow.mSession, appWindow.mClient, null);
+        dispatcher.attachToWindow(appWindow.mSession, appWindow.mClient, null, null);
 
 
         OnBackInvokedCallback appCallback = createBackCallback(appLatch);