diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/splitscreen/StageCoordinator.java b/libs/WindowManager/Shell/src/com/android/wm/shell/splitscreen/StageCoordinator.java
index c95c9f099f20dda7396b466a4d4ba440103775aa..1c792395d22d97fd5d85d9c1f8357cf796ff9396 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/splitscreen/StageCoordinator.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/splitscreen/StageCoordinator.java
@@ -2244,6 +2244,25 @@ public class StageCoordinator implements SplitLayout.SplitLayoutHandler,
         return SPLIT_POSITION_UNDEFINED;
     }
 
+    /**
+     * Returns the {@link StageType} where {@param token} is being used
+     * {@link SplitScreen#STAGE_TYPE_UNDEFINED} otherwise
+     */
+    @StageType
+    public int getSplitItemStage(@Nullable WindowContainerToken token) {
+        if (token == null) {
+            return STAGE_TYPE_UNDEFINED;
+        }
+
+        if (mMainStage.containsToken(token)) {
+            return STAGE_TYPE_MAIN;
+        } else if (mSideStage.containsToken(token)) {
+            return STAGE_TYPE_SIDE;
+        }
+
+        return STAGE_TYPE_UNDEFINED;
+    }
+
     @Override
     public void setLayoutOffsetTarget(int offsetX, int offsetY, SplitLayout layout) {
         final StageTaskListener topLeftStage =
@@ -2491,7 +2510,16 @@ public class StageCoordinator implements SplitLayout.SplitLayoutHandler,
                 mRecentTasks.ifPresent(
                         recentTasks -> recentTasks.removeSplitPair(triggerTask.taskId));
             }
-            prepareExitSplitScreen(STAGE_TYPE_UNDEFINED, outWCT);
+            @StageType int topStage = STAGE_TYPE_UNDEFINED;
+            if (isSplitScreenVisible()) {
+                // Get the stage where a child exists to keep that stage onTop
+                if (mMainStage.getChildCount() != 0 && mSideStage.getChildCount() == 0) {
+                    topStage = STAGE_TYPE_MAIN;
+                } else if (mSideStage.getChildCount() != 0 && mMainStage.getChildCount() == 0) {
+                    topStage = STAGE_TYPE_SIDE;
+                }
+            }
+            prepareExitSplitScreen(topStage, outWCT);
         }
     }
 
@@ -2908,7 +2936,11 @@ public class StageCoordinator implements SplitLayout.SplitLayoutHandler,
         return SPLIT_POSITION_UNDEFINED;
     }
 
-    /** Synchronize split-screen state with transition and make appropriate preparations. */
+    /**
+     * Synchronize split-screen state with transition and make appropriate preparations.
+     * @param toStage The stage that will not be dismissed. If set to
+     *        {@link SplitScreen#STAGE_TYPE_UNDEFINED} then both stages will be dismissed
+     */
     public void prepareDismissAnimation(@StageType int toStage, @ExitReason int dismissReason,
             @NonNull TransitionInfo info, @NonNull SurfaceControl.Transaction t,
             @NonNull SurfaceControl.Transaction finishT) {
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/transition/DefaultMixedHandler.java b/libs/WindowManager/Shell/src/com/android/wm/shell/transition/DefaultMixedHandler.java
index 4897c4ee7dd6f260d5566bf004ba484a0d50e316..f0bb665f8082fc874d3b074450bef8fbc853afce 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/transition/DefaultMixedHandler.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/transition/DefaultMixedHandler.java
@@ -50,6 +50,7 @@ import com.android.wm.shell.keyguard.KeyguardTransitionHandler;
 import com.android.wm.shell.pip.PipTransitionController;
 import com.android.wm.shell.protolog.ShellProtoLogGroup;
 import com.android.wm.shell.recents.RecentsTransitionHandler;
+import com.android.wm.shell.splitscreen.SplitScreen;
 import com.android.wm.shell.splitscreen.SplitScreenController;
 import com.android.wm.shell.splitscreen.StageCoordinator;
 import com.android.wm.shell.sysui.ShellInit;
@@ -511,8 +512,26 @@ public class DefaultMixedHandler implements Transitions.TransitionHandler,
             // make a new startTransaction because pip's startEnterAnimation "consumes" it so
             // we need a separate one to send over to launcher.
             SurfaceControl.Transaction otherStartT = new SurfaceControl.Transaction();
+            @SplitScreen.StageType int topStageToKeep = STAGE_TYPE_UNDEFINED;
+            if (mSplitHandler.isSplitScreenVisible()) {
+                // The non-going home case, we could be pip-ing one of the split stages and keep
+                // showing the other
+                for (int i = info.getChanges().size() - 1; i >= 0; --i) {
+                    TransitionInfo.Change change = info.getChanges().get(i);
+                    if (change == pipChange) {
+                        // Ignore the change/task that's going into Pip
+                        continue;
+                    }
+                    @SplitScreen.StageType int splitItemStage =
+                            mSplitHandler.getSplitItemStage(change.getLastParent());
+                    if (splitItemStage != STAGE_TYPE_UNDEFINED) {
+                        topStageToKeep = splitItemStage;
+                        break;
+                    }
+                }
+            }
             // Let split update internal state for dismiss.
-            mSplitHandler.prepareDismissAnimation(STAGE_TYPE_UNDEFINED,
+            mSplitHandler.prepareDismissAnimation(topStageToKeep,
                     EXIT_REASON_CHILD_TASK_ENTER_PIP, everythingElse, otherStartT,
                     finishTransaction);