diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
index e538e093e60f567e17e2b069ce2c5799fa42e725..2944bd9f9a8e5c069a1862092498eb23f30d40c7 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
@@ -174,8 +174,8 @@ private fun <T> computeValue(
     lerp: (T, T, Float) -> T,
     canOverflow: Boolean,
 ): T {
-    val state = layoutImpl.state.transitionState
-    if (state !is TransitionState.Transition || !layoutImpl.isTransitionReady(state)) {
+    val transition = layoutImpl.state.currentTransition
+    if (transition == null || !layoutImpl.isTransitionReady(transition)) {
         return sharedValue.value
     }
 
@@ -191,10 +191,11 @@ private fun <T> computeValue(
         return value as Element.SharedValue<T>
     }
 
-    val fromValue = sceneValue(state.fromScene)
-    val toValue = sceneValue(state.toScene)
+    val fromValue = sceneValue(transition.fromScene)
+    val toValue = sceneValue(transition.toScene)
     return if (fromValue != null && toValue != null) {
-        val progress = if (canOverflow) state.progress else state.progress.coerceIn(0f, 1f)
+        val progress =
+            if (canOverflow) transition.progress else transition.progress.coerceIn(0f, 1f)
         lerp(fromValue.value, toValue.value, progress)
     } else if (fromValue != null) {
         fromValue.value
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
index de69c37d4630021257206400c96ec0cb73a3e766..ba6d00e3b7f5890aa02f8a73125d531da81f1baf 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
@@ -28,11 +28,11 @@ import kotlinx.coroutines.launch
  * the currently running transition, if there is one.
  */
 internal fun CoroutineScope.animateToScene(
-    layoutImpl: SceneTransitionLayoutImpl,
+    layoutState: SceneTransitionLayoutStateImpl,
     target: SceneKey,
 ) {
-    val state = layoutImpl.state.transitionState
-    if (state.currentScene == target) {
+    val transitionState = layoutState.transitionState
+    if (transitionState.currentScene == target) {
         // This can happen in 3 different situations, for which there isn't anything else to do:
         //  1. There is no ongoing transition and [target] is already the current scene.
         //  2. The user is swiping to [target] from another scene and released their pointer such
@@ -44,44 +44,47 @@ internal fun CoroutineScope.animateToScene(
         return
     }
 
-    when (state) {
-        is TransitionState.Idle -> animate(layoutImpl, target)
+    when (transitionState) {
+        is TransitionState.Idle -> animate(layoutState, target)
         is TransitionState.Transition -> {
             // A transition is currently running: first check whether `transition.toScene` or
             // `transition.fromScene` is the same as our target scene, in which case the transition
             // can be accelerated or reversed to end up in the target state.
 
-            if (state.toScene == target) {
+            if (transitionState.toScene == target) {
                 // The user is currently swiping to [target] but didn't release their pointer yet:
                 // animate the progress to `1`.
 
-                check(state.fromScene == state.currentScene)
-                val progress = state.progress
+                check(transitionState.fromScene == transitionState.currentScene)
+                val progress = transitionState.progress
                 if ((1f - progress).absoluteValue < ProgressVisibilityThreshold) {
-                    // The transition is already finished (progress ~= 1): no need to animate.
-                    layoutImpl.state.transitionState = TransitionState.Idle(state.currentScene)
+                    // The transition is already finished (progress ~= 1): no need to animate. We
+                    // finish the current transition early to make sure that the current state
+                    // change is committed.
+                    layoutState.finishTransition(transitionState, transitionState.currentScene)
                 } else {
                     // The transition is in progress: start the canned animation at the same
                     // progress as it was in.
                     // TODO(b/290184746): Also take the current velocity into account.
-                    animate(layoutImpl, target, startProgress = progress)
+                    animate(layoutState, target, startProgress = progress)
                 }
 
                 return
             }
 
-            if (state.fromScene == target) {
+            if (transitionState.fromScene == target) {
                 // There is a transition from [target] to another scene: simply animate the same
                 // transition progress to `0`.
 
-                check(state.toScene == state.currentScene)
-                val progress = state.progress
+                check(transitionState.toScene == transitionState.currentScene)
+                val progress = transitionState.progress
                 if (progress.absoluteValue < ProgressVisibilityThreshold) {
-                    // The transition is at progress ~= 0: no need to animate.
-                    layoutImpl.state.transitionState = TransitionState.Idle(state.currentScene)
+                    // The transition is at progress ~= 0: no need to animate.We finish the current
+                    // transition early to make sure that the current state change is committed.
+                    layoutState.finishTransition(transitionState, transitionState.currentScene)
                 } else {
                     // TODO(b/290184746): Also take the current velocity into account.
-                    animate(layoutImpl, target, startProgress = progress, reversed = true)
+                    animate(layoutState, target, startProgress = progress, reversed = true)
                 }
 
                 return
@@ -89,27 +92,22 @@ internal fun CoroutineScope.animateToScene(
 
             // Generic interruption; the current transition is neither from or to [target].
             // TODO(b/290930950): Better handle interruptions here.
-            animate(layoutImpl, target)
+            animate(layoutState, target)
         }
     }
 }
 
 private fun CoroutineScope.animate(
-    layoutImpl: SceneTransitionLayoutImpl,
+    layoutState: SceneTransitionLayoutStateImpl,
     target: SceneKey,
     startProgress: Float = 0f,
     reversed: Boolean = false,
 ) {
-    val fromScene = layoutImpl.state.transitionState.currentScene
+    val fromScene = layoutState.transitionState.currentScene
     val isUserInput =
-        (layoutImpl.state.transitionState as? TransitionState.Transition)?.isInitiatedByUserInput
+        (layoutState.transitionState as? TransitionState.Transition)?.isInitiatedByUserInput
             ?: false
 
-    val animationSpec = layoutImpl.transitions.transitionSpec(fromScene, target).spec
-    val visibilityThreshold =
-        (animationSpec as? SpringSpec)?.visibilityThreshold ?: ProgressVisibilityThreshold
-    val animatable = Animatable(startProgress, visibilityThreshold = visibilityThreshold)
-
     val targetProgress = if (reversed) 0f else 1f
     val transition =
         if (reversed) {
@@ -119,7 +117,6 @@ private fun CoroutineScope.animate(
                 currentScene = target,
                 isInitiatedByUserInput = isUserInput,
                 isUserInputOngoing = false,
-                animatable = animatable,
             )
         } else {
             OneOffTransition(
@@ -128,21 +125,27 @@ private fun CoroutineScope.animate(
                 currentScene = target,
                 isInitiatedByUserInput = isUserInput,
                 isUserInputOngoing = false,
-                animatable = animatable,
             )
         }
 
-    // Change the current layout state to use this new transition.
-    layoutImpl.state.transitionState = transition
+    // Change the current layout state to start this new transition. This will compute the
+    // TransformationSpec associated to this transition, which we need to initialize the Animatable
+    // that will actually animate it.
+    layoutState.startTransition(transition)
+
+    // The transformation now contains the spec that we should use to instantiate the Animatable.
+    val animationSpec = layoutState.transformationSpec.progressSpec
+    val visibilityThreshold =
+        (animationSpec as? SpringSpec)?.visibilityThreshold ?: ProgressVisibilityThreshold
+    val animatable =
+        Animatable(startProgress, visibilityThreshold = visibilityThreshold).also {
+            transition.animatable = it
+        }
 
     // Animate the progress to its target value.
     launch {
         animatable.animateTo(targetProgress, animationSpec)
-
-        // Unless some other external state change happened, the state should now be idle.
-        if (layoutImpl.state.transitionState == transition) {
-            layoutImpl.state.transitionState = TransitionState.Idle(target)
-        }
+        layoutState.finishTransition(transition, target)
     }
 }
 
@@ -152,8 +155,16 @@ private class OneOffTransition(
     override val currentScene: SceneKey,
     override val isInitiatedByUserInput: Boolean,
     override val isUserInputOngoing: Boolean,
-    private val animatable: Animatable<Float, AnimationVector1D>,
 ) : TransitionState.Transition(fromScene, toScene) {
+    /**
+     * The animatable used to animate this transition.
+     *
+     * Note: This is lateinit because we need to first create this Transition object so that
+     * [SceneTransitionLayoutState] can compute the transformations and animation spec associated to
+     * it, which is need to initialize this Animatable.
+     */
+    lateinit var animatable: Animatable<Float, AnimationVector1D>
+
     override val progress: Float
         get() = animatable.value
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index 431a8aef6d3daeb2e4d0fcb4254f39b3b74cd75c..5dc1079e8b568383131185aef57590be756784a6 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -181,15 +181,11 @@ private data class ElementModifier(
 }
 
 internal class ElementNode(
-    layoutImpl: SceneTransitionLayoutImpl,
-    scene: Scene,
-    element: Element,
-    sceneValues: Element.TargetValues,
+    private var layoutImpl: SceneTransitionLayoutImpl,
+    private var scene: Scene,
+    private var element: Element,
+    private var sceneValues: Element.TargetValues,
 ) : Modifier.Node(), DrawModifierNode {
-    private var layoutImpl: SceneTransitionLayoutImpl = layoutImpl
-    private var scene: Scene = scene
-    private var element: Element = element
-    private var sceneValues: Element.TargetValues = sceneValues
 
     override fun onAttach() {
         super.onAttach()
@@ -283,26 +279,27 @@ private fun shouldDrawElement(
     scene: Scene,
     element: Element,
 ): Boolean {
-    val state = layoutImpl.state.transitionState
+    val transition = layoutImpl.state.currentTransition
 
     // Always draw the element if there is no ongoing transition or if the element is not shared.
     if (
-        state !is TransitionState.Transition ||
-            !layoutImpl.isTransitionReady(state) ||
-            state.fromScene !in element.sceneValues ||
-            state.toScene !in element.sceneValues
+        transition == null ||
+            !layoutImpl.isTransitionReady(transition) ||
+            transition.fromScene !in element.sceneValues ||
+            transition.toScene !in element.sceneValues
     ) {
         return true
     }
 
-    val sharedTransformation = sharedElementTransformation(layoutImpl, state, element.key)
+    val sharedTransformation =
+        sharedElementTransformation(layoutImpl.state, transition, element.key)
     if (sharedTransformation?.enabled == false) {
         return true
     }
 
     return shouldDrawOrComposeSharedElement(
         layoutImpl,
-        state,
+        transition,
         scene.key,
         element.key,
         sharedTransformation,
@@ -331,21 +328,21 @@ internal fun shouldDrawOrComposeSharedElement(
 }
 
 private fun isSharedElementEnabled(
-    layoutImpl: SceneTransitionLayoutImpl,
+    layoutState: SceneTransitionLayoutStateImpl,
     transition: TransitionState.Transition,
     element: ElementKey,
 ): Boolean {
-    return sharedElementTransformation(layoutImpl, transition, element)?.enabled ?: true
+    return sharedElementTransformation(layoutState, transition, element)?.enabled ?: true
 }
 
 internal fun sharedElementTransformation(
-    layoutImpl: SceneTransitionLayoutImpl,
+    layoutState: SceneTransitionLayoutStateImpl,
     transition: TransitionState.Transition,
     element: ElementKey,
 ): SharedElementTransformation? {
-    val spec = layoutImpl.transitions.transitionSpec(transition.fromScene, transition.toScene)
-    val sharedInFromScene = spec.transformations(element, transition.fromScene).shared
-    val sharedInToScene = spec.transformations(element, transition.toScene).shared
+    val transformationSpec = layoutState.transformationSpec
+    val sharedInFromScene = transformationSpec.transformations(element, transition.fromScene).shared
+    val sharedInToScene = transformationSpec.transformations(element, transition.toScene).shared
 
     // The sharedElement() transformation must either be null or be the same in both scenes.
     if (sharedInFromScene != sharedInToScene) {
@@ -371,13 +368,9 @@ private fun isElementOpaque(
     scene: Scene,
     sceneValues: Element.TargetValues,
 ): Boolean {
-    val state = layoutImpl.state.transitionState
-
-    if (state !is TransitionState.Transition) {
-        return true
-    }
+    val transition = layoutImpl.state.currentTransition ?: return true
 
-    if (!layoutImpl.isTransitionReady(state)) {
+    if (!layoutImpl.isTransitionReady(transition)) {
         val lastValue =
             sceneValues.lastValues.alpha.takeIf { it != Element.AlphaUnspecified }
                 ?: element.lastSharedValues.alpha.takeIf { it != Element.AlphaUnspecified } ?: 1f
@@ -385,8 +378,8 @@ private fun isElementOpaque(
         return lastValue == 1f
     }
 
-    val fromScene = state.fromScene
-    val toScene = state.toScene
+    val fromScene = transition.fromScene
+    val toScene = transition.toScene
     val fromValues = element.sceneValues[fromScene]
     val toValues = element.sceneValues[toScene]
 
@@ -395,14 +388,11 @@ private fun isElementOpaque(
     }
 
     val isSharedElement = fromValues != null && toValues != null
-    if (isSharedElement && isSharedElementEnabled(layoutImpl, state, element.key)) {
+    if (isSharedElement && isSharedElementEnabled(layoutImpl.state, transition, element.key)) {
         return true
     }
 
-    return layoutImpl.transitions
-        .transitionSpec(fromScene, toScene)
-        .transformations(element.key, scene.key)
-        .alpha == null
+    return layoutImpl.state.transformationSpec.transformations(element.key, scene.key).alpha == null
 }
 
 /**
@@ -607,24 +597,22 @@ private inline fun <T> computeValue(
     lastValue: () -> T,
     lerp: (T, T, Float) -> T,
 ): T {
-    val state = layoutImpl.state.transitionState
-
-    // There is no ongoing transition.
-    if (state !is TransitionState.Transition) {
-        // Even if this element SceneTransitionLayout is not animated, the layout itself might be
-        // animated (e.g. by another parent SceneTransitionLayout), in which case this element still
-        // need to participate in the layout phase.
-        return currentValue()
-    }
+    val transition =
+        layoutImpl.state.currentTransition
+        // There is no ongoing transition. Even if this element SceneTransitionLayout is not
+        // animated, the layout itself might be animated (e.g. by another parent
+        // SceneTransitionLayout), in which case this element still need to participate in the
+        // layout phase.
+        ?: return currentValue()
 
     // A transition was started but it's not ready yet (not all elements have been composed/laid
     // out yet). Use the last value that was set, to make sure elements don't unexpectedly jump.
-    if (!layoutImpl.isTransitionReady(state)) {
+    if (!layoutImpl.isTransitionReady(transition)) {
         return lastValue()
     }
 
-    val fromScene = state.fromScene
-    val toScene = state.toScene
+    val fromScene = transition.fromScene
+    val toScene = transition.toScene
     val fromValues = element.sceneValues[fromScene]
     val toValues = element.sceneValues[toScene]
 
@@ -638,21 +626,17 @@ private inline fun <T> computeValue(
     // TODO(b/290184746): Support non linear shared paths as well as a way to make sure that shared
     // elements follow the finger direction.
     val isSharedElement = fromValues != null && toValues != null
-    if (isSharedElement && isSharedElementEnabled(layoutImpl, state, element.key)) {
+    if (isSharedElement && isSharedElementEnabled(layoutImpl.state, transition, element.key)) {
         val start = sceneValue(fromValues!!)
         val end = sceneValue(toValues!!)
 
         // Make sure we don't read progress if values are the same and we don't need to interpolate,
         // so we don't invalidate the phase where this is read.
-        return if (start == end) start else lerp(start, end, state.progress)
+        return if (start == end) start else lerp(start, end, transition.progress)
     }
 
     val transformation =
-        transformation(
-            layoutImpl.transitions
-                .transitionSpec(fromScene, toScene)
-                .transformations(element.key, scene.key)
-        )
+        transformation(layoutImpl.state.transformationSpec.transformations(element.key, scene.key))
         // If there is no transformation explicitly associated to this element value, let's use
         // the value given by the system (like the current position and size given by the layout
         // pass).
@@ -675,7 +659,7 @@ private inline fun <T> computeValue(
             scene,
             element,
             sceneValues,
-            state,
+            transition,
             idleValue,
         )
 
@@ -685,7 +669,7 @@ private inline fun <T> computeValue(
         return targetValue
     }
 
-    val progress = state.progress
+    val progress = transition.progress
     // TODO(b/290184746): Make sure that we don't overflow transformations associated to a range.
     val rangeProgress = transformation.range?.progress(progress) ?: progress
 
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
index 7029da2edb0d5e05b34217092344ccb4fb6327c7..306f27626e196b03af3d1e0c882b738ed193e5c2 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
@@ -120,17 +120,13 @@ private fun shouldComposeMovableElement(
     scene: SceneKey,
     element: Element,
 ): Boolean {
-    val transitionState = layoutImpl.state.transitionState
-
-    // If we are idle, there is only one [scene] that is composed so we can compose our movable
-    // content here.
-    if (transitionState is TransitionState.Idle) {
-        check(transitionState.currentScene == scene)
-        return true
-    }
-
-    val fromScene = (transitionState as TransitionState.Transition).fromScene
-    val toScene = transitionState.toScene
+    val transition =
+        layoutImpl.state.currentTransition
+        // If we are idle, there is only one [scene] that is composed so we can compose our
+        // movable content here.
+        ?: return true
+    val fromScene = transition.fromScene
+    val toScene = transition.toScene
 
     val fromReady = layoutImpl.isSceneReady(fromScene)
     val toReady = layoutImpl.isSceneReady(toScene)
@@ -181,10 +177,10 @@ private fun shouldComposeMovableElement(
 
     return shouldDrawOrComposeSharedElement(
         layoutImpl,
-        transitionState,
+        transition,
         scene,
         element.key,
-        sharedElementTransformation(layoutImpl, transitionState, element.key),
+        sharedElementTransformation(layoutImpl.state, transition, element.key),
     )
 }
 
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/NestedScrollToScene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/NestedScrollToScene.kt
index 32025b4f1258a2ecdb7821e7dc1992c1c3226734..e78f3266d664faca0bc0b9659a36278920a42eee 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/NestedScrollToScene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/NestedScrollToScene.kt
@@ -179,7 +179,8 @@ private fun scenePriorityNestedScrollConnection(
     bottomOrRightBehavior: NestedScrollBehavior,
 ) =
     SceneNestedScrollHandler(
-            gestureHandler = layoutImpl.gestureHandler(orientation = orientation),
+            layoutImpl = layoutImpl,
+            orientation = orientation,
             topOrLeftBehavior = topOrLeftBehavior,
             bottomOrRightBehavior = bottomOrRightBehavior,
         )
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneGestureHandler.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneGestureHandler.kt
index 91decf4d8b7ebd1c7c67b45fc66a0d602e86f0c4..338557d0942e9bb29eb523a642572d93010aa6a6 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneGestureHandler.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneGestureHandler.kt
@@ -41,14 +41,9 @@ internal class SceneGestureHandler(
     internal val orientation: Orientation,
     private val coroutineScope: CoroutineScope,
 ) {
+    private val layoutState = layoutImpl.state
     val draggable: DraggableHandler = SceneDraggableHandler(this)
 
-    internal var transitionState
-        get() = layoutImpl.state.transitionState
-        set(value) {
-            layoutImpl.state.transitionState = value
-        }
-
     private var _swipeTransition: SwipeTransition? = null
     internal var swipeTransition: SwipeTransition
         get() = _swipeTransition ?: error("SwipeTransition needs to be initialized")
@@ -57,27 +52,26 @@ internal class SceneGestureHandler(
         }
 
     private fun updateTransition(newTransition: SwipeTransition, force: Boolean = false) {
-        if (isDrivingTransition || force) transitionState = newTransition
+        if (isDrivingTransition || force) layoutState.startTransition(newTransition)
         swipeTransition = newTransition
     }
 
-    internal val currentScene: Scene
-        get() = layoutImpl.scene(transitionState.currentScene)
-
     internal val isDrivingTransition
-        get() = transitionState == _swipeTransition
+        get() = layoutState.transitionState == _swipeTransition
 
     /**
      * The velocity threshold at which the intent of the user is to swipe up or down. It is the same
      * as SwipeableV2Defaults.VelocityThreshold.
      */
-    internal val velocityThreshold = with(layoutImpl.density) { 125.dp.toPx() }
+    internal val velocityThreshold: Float
+        get() = with(layoutImpl.density) { 125.dp.toPx() }
 
     /**
      * The positional threshold at which the intent of the user is to swipe to the next scene. It is
      * the same as SwipeableV2Defaults.PositionalThreshold.
      */
-    private val positionalThreshold = with(layoutImpl.density) { 56.dp.toPx() }
+    private val positionalThreshold
+        get() = with(layoutImpl.density) { 56.dp.toPx() }
 
     internal var gestureWithPriority: Any? = null
 
@@ -98,18 +92,18 @@ internal class SceneGestureHandler(
             return
         }
 
-        val transition = transitionState
-        if (transition is TransitionState.Transition) {
+        val transitionState = layoutState.transitionState
+        if (transitionState is TransitionState.Transition) {
             // TODO(b/290184746): Better handle interruptions here if state != idle.
             Log.w(
                 TAG,
                 "start from TransitionState.Transition is not fully supported: from" +
-                    " ${transition.fromScene} to ${transition.toScene} " +
-                    "(progress ${transition.progress})"
+                    " ${transitionState.fromScene} to ${transitionState.toScene} " +
+                    "(progress ${transitionState.progress})"
             )
         }
 
-        val fromScene = currentScene
+        val fromScene = layoutImpl.scene(transitionState.currentScene)
         setCurrentActions(fromScene, startedPosition, pointersDown)
 
         val (targetScene, distance) =
@@ -364,7 +358,7 @@ internal class SceneGestureHandler(
                     findTargetSceneAndDistanceStrict(fromScene, velocity)
                         ?: run {
                             // We will not animate
-                            transitionState = TransitionState.Idle(fromScene.key)
+                            layoutState.finishTransition(swipeTransition, idleScene = fromScene.key)
                             return
                         }
 
@@ -439,14 +433,7 @@ internal class SceneGestureHandler(
                 )
 
                 swipeTransition.finishOffsetAnimation()
-
-                // Now that the animation is done, the state should be idle. Note that if the state
-                // was changed since this animation started, some external code changed it and we
-                // shouldn't do anything here. Note also that this job will be cancelled in the case
-                // where the user intercepts this swipe.
-                if (isDrivingTransition) {
-                    transitionState = TransitionState.Idle(targetScene)
-                }
+                layoutState.finishTransition(swipeTransition, targetScene)
             }
         }
     }
@@ -539,10 +526,14 @@ private class SceneDraggableHandler(
 }
 
 internal class SceneNestedScrollHandler(
-    private val gestureHandler: SceneGestureHandler,
+    private val layoutImpl: SceneTransitionLayoutImpl,
+    private val orientation: Orientation,
     private val topOrLeftBehavior: NestedScrollBehavior,
     private val bottomOrRightBehavior: NestedScrollBehavior,
 ) : NestedScrollHandler {
+    private val layoutState = layoutImpl.state
+    private val gestureHandler = layoutImpl.gestureHandler(orientation)
+
     override val connection: PriorityNestedScrollConnection = nestedScrollConnection()
 
     private fun nestedScrollConnection(): PriorityNestedScrollConnection {
@@ -553,7 +544,7 @@ internal class SceneNestedScrollHandler(
         val actionUpOrLeft =
             Swipe(
                 direction =
-                    when (gestureHandler.orientation) {
+                    when (orientation) {
                         Orientation.Horizontal -> SwipeDirection.Left
                         Orientation.Vertical -> SwipeDirection.Up
                     },
@@ -563,7 +554,7 @@ internal class SceneNestedScrollHandler(
         val actionDownOrRight =
             Swipe(
                 direction =
-                    when (gestureHandler.orientation) {
+                    when (orientation) {
                         Orientation.Horizontal -> SwipeDirection.Right
                         Orientation.Vertical -> SwipeDirection.Down
                     },
@@ -571,7 +562,7 @@ internal class SceneNestedScrollHandler(
             )
 
         fun hasNextScene(amount: Float): Boolean {
-            val fromScene = gestureHandler.currentScene
+            val fromScene = layoutImpl.scene(layoutState.transitionState.currentScene)
             val nextScene =
                 when {
                     amount < 0f -> fromScene.userActions[actionUpOrLeft]
@@ -582,7 +573,7 @@ internal class SceneNestedScrollHandler(
         }
 
         return PriorityNestedScrollConnection(
-            orientation = gestureHandler.orientation,
+            orientation = orientation,
             canStartPreScroll = { offsetAvailable, offsetBeforeStart ->
                 canChangeScene = offsetBeforeStart == 0f
 
@@ -590,8 +581,9 @@ internal class SceneNestedScrollHandler(
                     canChangeScene && gestureHandler.isDrivingTransition && offsetAvailable != 0f
                 if (!canInterceptSwipeTransition) return@PriorityNestedScrollConnection false
 
-                val progress = gestureHandler.swipeTransition.progress
-                val threshold = gestureHandler.layoutImpl.transitionInterceptionThreshold
+                val swipeTransition = gestureHandler.swipeTransition
+                val progress = swipeTransition.progress
+                val threshold = layoutImpl.transitionInterceptionThreshold
                 fun isProgressCloseTo(value: Float) = (progress - value).absoluteValue <= threshold
 
                 // The transition is always between 0 and 1. If it is close to either of these
@@ -599,9 +591,8 @@ internal class SceneNestedScrollHandler(
                 // The progress value can go beyond this range in the case of overscroll.
                 val shouldSnapToIdle = isProgressCloseTo(0f) || isProgressCloseTo(1f)
                 if (shouldSnapToIdle) {
-                    gestureHandler.swipeTransition.cancelOffsetAnimation()
-                    gestureHandler.transitionState =
-                        TransitionState.Idle(gestureHandler.swipeTransition.currentScene)
+                    swipeTransition.cancelOffsetAnimation()
+                    layoutState.finishTransition(swipeTransition, swipeTransition.currentScene)
                 }
 
                 // Start only if we cannot consume this event
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
index 239971ff6be825e3c7d6c2ec9ce8286686ba7484..3608e374fdbc83ea3fdd8ca39d2ec108dd81a9a7 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
@@ -19,6 +19,8 @@ package com.android.compose.animation.scene
 import androidx.annotation.FloatRange
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.runtime.Composable
+import androidx.compose.runtime.LaunchedEffect
+import androidx.compose.runtime.SideEffect
 import androidx.compose.runtime.Stable
 import androidx.compose.runtime.State
 import androidx.compose.runtime.remember
@@ -27,6 +29,7 @@ import androidx.compose.ui.Modifier
 import androidx.compose.ui.graphics.Shape
 import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
 import androidx.compose.ui.platform.LocalDensity
+import kotlinx.coroutines.channels.Channel
 
 /**
  * [SceneTransitionLayout] is a container that automatically animates its content whenever
@@ -266,24 +269,45 @@ internal fun SceneTransitionLayoutForTesting(
     val coroutineScope = rememberCoroutineScope()
     val layoutImpl = remember {
         SceneTransitionLayoutImpl(
+                state = state as SceneTransitionLayoutStateImpl,
                 onChangeScene = onChangeScene,
-                builder = scenes,
-                transitions = transitions,
-                state = state,
                 density = density,
                 edgeDetector = edgeDetector,
                 transitionInterceptionThreshold = transitionInterceptionThreshold,
+                builder = scenes,
                 coroutineScope = coroutineScope,
             )
             .also { onLayoutImpl?.invoke(it) }
     }
 
-    layoutImpl.onChangeScene = onChangeScene
-    layoutImpl.transitions = transitions
-    layoutImpl.density = density
-    layoutImpl.edgeDetector = edgeDetector
+    val targetSceneChannel = remember { Channel<SceneKey>(Channel.CONFLATED) }
+    SideEffect {
+        if (state != layoutImpl.state) {
+            error(
+                "This SceneTransitionLayout was bound to a different SceneTransitionLayoutState" +
+                    " that was used when creating it, which is not supported"
+            )
+        }
+
+        layoutImpl.onChangeScene = onChangeScene
+        (state as SceneTransitionLayoutStateImpl).transitions = transitions
+        layoutImpl.density = density
+        layoutImpl.edgeDetector = edgeDetector
+        layoutImpl.updateScenes(scenes)
+
+        state.transitions = transitions
+
+        targetSceneChannel.trySend(currentScene)
+    }
+
+    LaunchedEffect(targetSceneChannel) {
+        for (newKey in targetSceneChannel) {
+            // Inspired by AnimateAsState.kt: let's poll the last value to avoid being one frame
+            // late.
+            val newKey = targetSceneChannel.tryReceive().getOrNull() ?: newKey
+            animateToScene(layoutImpl.state, newKey)
+        }
+    }
 
-    layoutImpl.setScenes(scenes)
-    layoutImpl.setCurrentScene(currentScene)
     layoutImpl.Content(modifier)
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
index 00e33e24c41e8e050f3baaf3102a14033bfd6992..c99c3250bbb108206d29c90f81a9b89f490d7596 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
@@ -22,13 +22,8 @@ import androidx.compose.foundation.layout.Box
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.DisposableEffect
 import androidx.compose.runtime.LaunchedEffect
-import androidx.compose.runtime.SideEffect
 import androidx.compose.runtime.Stable
-import androidx.compose.runtime.getValue
 import androidx.compose.runtime.key
-import androidx.compose.runtime.mutableStateOf
-import androidx.compose.runtime.remember
-import androidx.compose.runtime.setValue
 import androidx.compose.runtime.snapshots.SnapshotStateMap
 import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
@@ -40,36 +35,40 @@ import androidx.compose.ui.unit.IntSize
 import androidx.compose.ui.util.fastForEach
 import com.android.compose.ui.util.lerp
 import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.channels.Channel
 
 @Stable
 internal class SceneTransitionLayoutImpl(
-    onChangeScene: (SceneKey) -> Unit,
+    internal val state: SceneTransitionLayoutStateImpl,
+    internal var onChangeScene: (SceneKey) -> Unit,
+    internal var density: Density,
+    internal var edgeDetector: EdgeDetector,
+    internal var transitionInterceptionThreshold: Float,
     builder: SceneTransitionLayoutScope.() -> Unit,
-    transitions: SceneTransitions,
-    internal val state: SceneTransitionLayoutState,
-    density: Density,
-    edgeDetector: EdgeDetector,
-    transitionInterceptionThreshold: Float,
     coroutineScope: CoroutineScope,
 ) {
-    internal val scenes = SnapshotStateMap<SceneKey, Scene>()
+    internal val scenes = mutableMapOf<SceneKey, Scene>()
+
+    /**
+     * The map of [Element]s.
+     *
+     * Note that this map is *mutated* directly during composition, so it is a [SnapshotStateMap] to
+     * make sure that mutations are reverted if composition is cancelled.
+     */
     internal val elements = SnapshotStateMap<ElementKey, Element>()
 
-    /** The scenes that are "ready", i.e. they were composed and fully laid-out at least once. */
+    /**
+     * The scenes that are "ready", i.e. they were composed and fully laid-out at least once.
+     *
+     * Note that this map is *read* during composition, so it is a [SnapshotStateMap] to make sure
+     * that we recompose when modifications are made to this map.
+     */
     private val readyScenes = SnapshotStateMap<SceneKey, Boolean>()
 
-    internal var onChangeScene by mutableStateOf(onChangeScene)
-    internal var transitions by mutableStateOf(transitions)
-    internal var density: Density by mutableStateOf(density)
-    internal var edgeDetector by mutableStateOf(edgeDetector)
-    internal var transitionInterceptionThreshold by mutableStateOf(transitionInterceptionThreshold)
-
     private val horizontalGestureHandler: SceneGestureHandler
     private val verticalGestureHandler: SceneGestureHandler
 
     init {
-        setScenes(builder)
+        updateScenes(builder)
 
         // SceneGestureHandler must wait for the scenes to be initialized, in order to access the
         // current scene (required for SwipeTransition).
@@ -98,7 +97,7 @@ internal class SceneTransitionLayoutImpl(
         return scenes[key] ?: error("Scene $key is not configured")
     }
 
-    internal fun setScenes(builder: SceneTransitionLayoutScope.() -> Unit) {
+    internal fun updateScenes(builder: SceneTransitionLayoutScope.() -> Unit) {
         // Keep a reference of the current scenes. After processing [builder], the scenes that were
         // not configured will be removed.
         val scenesToRemove = scenes.keys.toMutableSet()
@@ -140,20 +139,6 @@ internal class SceneTransitionLayoutImpl(
         scenesToRemove.forEach { scenes.remove(it) }
     }
 
-    @Composable
-    internal fun setCurrentScene(key: SceneKey) {
-        val channel = remember { Channel<SceneKey>(Channel.CONFLATED) }
-        SideEffect { channel.trySend(key) }
-        LaunchedEffect(channel) {
-            for (newKey in channel) {
-                // Inspired by AnimateAsState.kt: let's poll the last value to avoid being one frame
-                // late.
-                val newKey = channel.tryReceive().getOrNull() ?: newKey
-                animateToScene(this@SceneTransitionLayoutImpl, newKey)
-            }
-        }
-    }
-
     @Composable
     @OptIn(ExperimentalComposeUiApi::class)
     internal fun Content(modifier: Modifier) {
@@ -171,14 +156,14 @@ internal class SceneTransitionLayoutImpl(
 
                     val width: Int
                     val height: Int
-                    val state = state.transitionState
-                    if (state !is TransitionState.Transition) {
+                    val transition = state.currentTransition
+                    if (transition == null) {
                         width = placeable.width
                         height = placeable.height
                     } else {
                         // Interpolate the size.
-                        val fromSize = scene(state.fromScene).targetSize
-                        val toSize = scene(state.toScene).targetSize
+                        val fromSize = scene(transition.fromScene).targetSize
+                        val toSize = scene(transition.toScene).targetSize
 
                         // Optimization: make sure we don't read state.progress if fromSize ==
                         // toSize to avoid running this code every frame when the layout size does
@@ -187,7 +172,7 @@ internal class SceneTransitionLayoutImpl(
                             width = fromSize.width
                             height = fromSize.height
                         } else {
-                            val size = lerp(fromSize, toSize, state.progress)
+                            val size = lerp(fromSize, toSize, transition.progress)
                             width = size.width.coerceAtLeast(0)
                             height = size.height.coerceAtLeast(0)
                         }
@@ -228,13 +213,12 @@ internal class SceneTransitionLayoutImpl(
 
                             scene.Content(
                                 Modifier.drawWithContent {
-                                    when (val state = state.transitionState) {
-                                        is TransitionState.Idle -> drawContent()
-                                        is TransitionState.Transition -> {
-                                            // Don't draw scenes that are not ready yet.
-                                            if (readyScenes.containsKey(key)) {
-                                                drawContent()
-                                            }
+                                    if (state.currentTransition == null) {
+                                        drawContent()
+                                    } else {
+                                        // Don't draw scenes that are not ready yet.
+                                        if (readyScenes.containsKey(key)) {
+                                            drawContent()
                                         }
                                     }
                                 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index 623725582a9d8c7fe5462b2e2bdfbe4aeb49c2b3..d1ba582d6c2313649cf02b9bfb52dc9f090d36a2 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -23,31 +23,32 @@ import androidx.compose.runtime.setValue
 
 /** The state of a [SceneTransitionLayout]. */
 @Stable
-class SceneTransitionLayoutState(initialScene: SceneKey) {
+sealed interface SceneTransitionLayoutState {
     /**
      * The current [TransitionState]. All values read here are backed by the Snapshot system.
      *
      * To observe those values outside of Compose/the Snapshot system, use
      * [SceneTransitionLayoutState.observableTransitionState] instead.
      */
-    var transitionState: TransitionState by mutableStateOf(TransitionState.Idle(initialScene))
+    val transitionState: TransitionState
+
+    /** The current transition, or `null` if we are idle. */
+    val currentTransition: TransitionState.Transition?
+        get() = transitionState as? TransitionState.Transition
 
     /**
-     * Whether we are transitioning, optionally restricting the check to the transition between
-     * [from] and [to].
+     * Whether we are transitioning. If [from] or [to] is empty, we will also check that they match
+     * the scenes we are animating from and/or to.
      */
-    fun isTransitioning(from: SceneKey? = null, to: SceneKey? = null): Boolean {
-        val transition = transitionState as? TransitionState.Transition ?: return false
-
-        return (from == null || transition.fromScene == from) &&
-            (to == null || transition.toScene == to)
-    }
+    fun isTransitioning(from: SceneKey? = null, to: SceneKey? = null): Boolean
 
     /** Whether we are transitioning from [scene] to [other], or from [other] to [scene]. */
-    fun isTransitioningBetween(scene: SceneKey, other: SceneKey): Boolean {
-        return isTransitioning(from = scene, to = other) ||
-            isTransitioning(from = other, to = scene)
-    }
+    fun isTransitioningBetween(scene: SceneKey, other: SceneKey): Boolean
+}
+
+/** Create a new [SceneTransitionLayoutState] that is currently idle at scene [currentScene]. */
+fun SceneTransitionLayoutState(currentScene: SceneKey): SceneTransitionLayoutState {
+    return SceneTransitionLayoutStateImpl(currentScene, SceneTransitions.Empty)
 }
 
 @Stable
@@ -93,3 +94,50 @@ sealed interface TransitionState {
         abstract val isUserInputOngoing: Boolean
     }
 }
+
+internal class SceneTransitionLayoutStateImpl(
+    initialScene: SceneKey,
+    internal var transitions: SceneTransitions,
+) : SceneTransitionLayoutState {
+    override var transitionState: TransitionState by
+        mutableStateOf(TransitionState.Idle(initialScene))
+        private set
+
+    /**
+     * The current [transformationSpec] associated to [transitionState]. Accessing this value makes
+     * sense only if [transitionState] is a [TransitionState.Transition].
+     */
+    internal var transformationSpec: TransformationSpecImpl = TransformationSpec.Empty
+
+    override fun isTransitioning(from: SceneKey?, to: SceneKey?): Boolean {
+        val transition = currentTransition ?: return false
+        return (from == null || transition.fromScene == from) &&
+            (to == null || transition.toScene == to)
+    }
+
+    override fun isTransitioningBetween(scene: SceneKey, other: SceneKey): Boolean {
+        return isTransitioning(from = scene, to = other) ||
+            isTransitioning(from = other, to = scene)
+    }
+
+    /** Start a new [transition], instantly interrupting any ongoing transition if there was one. */
+    internal fun startTransition(transition: TransitionState.Transition) {
+        // Compute the [TransformationSpec] when the transition starts.
+        transformationSpec =
+            transitions
+                .transitionSpec(transition.fromScene, transition.toScene)
+                .transformationSpec()
+
+        transitionState = transition
+    }
+
+    /**
+     * Notify that [transition] was finished and that we should settle to [idleScene]. This will do
+     * nothing if [transition] was interrupted since it was started.
+     */
+    internal fun finishTransition(transition: TransitionState.Transition, idleScene: SceneKey) {
+        if (transitionState == transition) {
+            transitionState = TransitionState.Idle(idleScene)
+        }
+    }
+}
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
index f91895bb0e05f40f38f0e85d8a9c2eb1e3e34f9f..3a55567d69bba04bf7a1983bdf2643fc1f020bbd 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
@@ -18,11 +18,9 @@ package com.android.compose.animation.scene
 
 import androidx.compose.animation.core.AnimationSpec
 import androidx.compose.animation.core.snap
-import androidx.compose.runtime.Stable
 import androidx.compose.ui.geometry.Offset
 import androidx.compose.ui.unit.IntSize
 import androidx.compose.ui.util.fastForEach
-import androidx.compose.ui.util.fastMap
 import com.android.compose.animation.scene.transformation.AnchoredSize
 import com.android.compose.animation.scene.transformation.AnchoredTranslate
 import com.android.compose.animation.scene.transformation.DrawScale
@@ -36,16 +34,17 @@ import com.android.compose.animation.scene.transformation.Transformation
 import com.android.compose.animation.scene.transformation.Translate
 
 /** The transitions configuration of a [SceneTransitionLayout]. */
-class SceneTransitions(
-    internal val transitionSpecs: List<TransitionSpec>,
+class SceneTransitions
+internal constructor(
+    internal val transitionSpecs: List<TransitionSpecImpl>,
 ) {
-    private val cache = mutableMapOf<SceneKey, MutableMap<SceneKey, TransitionSpec>>()
+    private val cache = mutableMapOf<SceneKey, MutableMap<SceneKey, TransitionSpecImpl>>()
 
-    internal fun transitionSpec(from: SceneKey, to: SceneKey): TransitionSpec {
+    internal fun transitionSpec(from: SceneKey, to: SceneKey): TransitionSpecImpl {
         return cache.getOrPut(from) { mutableMapOf() }.getOrPut(to) { findSpec(from, to) }
     }
 
-    private fun findSpec(from: SceneKey, to: SceneKey): TransitionSpec {
+    private fun findSpec(from: SceneKey, to: SceneKey): TransitionSpecImpl {
         val spec = transition(from, to) { it.from == from && it.to == to }
         if (spec != null) {
             return spec
@@ -53,7 +52,7 @@ class SceneTransitions(
 
         val reversed = transition(from, to) { it.from == to && it.to == from }
         if (reversed != null) {
-            return reversed.reverse()
+            return reversed.reversed()
         }
 
         val relaxedSpec =
@@ -67,16 +66,16 @@ class SceneTransitions(
         return transition(from, to) {
                 (it.from == to && it.to == null) || (it.to == from && it.from == null)
             }
-            ?.reverse()
+            ?.reversed()
             ?: defaultTransition(from, to)
     }
 
     private fun transition(
         from: SceneKey,
         to: SceneKey,
-        filter: (TransitionSpec) -> Boolean,
-    ): TransitionSpec? {
-        var match: TransitionSpec? = null
+        filter: (TransitionSpecImpl) -> Boolean,
+    ): TransitionSpecImpl? {
+        var match: TransitionSpecImpl? = null
         transitionSpecs.fastForEach { spec ->
             if (filter(spec)) {
                 if (match != null) {
@@ -89,28 +88,88 @@ class SceneTransitions(
     }
 
     private fun defaultTransition(from: SceneKey, to: SceneKey) =
-        TransitionSpec(from, to, emptyList(), snap())
+        TransitionSpecImpl(from, to, TransformationSpec.EmptyProvider)
+
+    companion object {
+        val Empty = SceneTransitions(transitionSpecs = emptyList())
+    }
 }
 
 /** The definition of a transition between [from] and [to]. */
-@Stable
-data class TransitionSpec(
-    val from: SceneKey?,
-    val to: SceneKey?,
-    val transformations: List<Transformation>,
-    val spec: AnimationSpec<Float>,
-) {
-    // TODO(b/302300957): Make sure this cache does not infinitely grow.
-    private val cache = mutableMapOf<ElementKey, MutableMap<SceneKey, ElementTransformations>>()
+interface TransitionSpec {
+    /**
+     * The scene we are transitioning from. If `null`, this spec can be used to animate from any
+     * scene.
+     */
+    val from: SceneKey?
+
+    /**
+     * The scene we are transitioning to. If `null`, this spec can be used to animate from any
+     * scene.
+     */
+    val to: SceneKey?
+
+    /**
+     * Return a reversed version of this [TransitionSpec] for a transition going from [to] to
+     * [from].
+     */
+    fun reversed(): TransitionSpec
+
+    /*
+     * The [TransformationSpec] associated to this [TransitionSpec].
+     *
+     * Note that this is called once every a transition associated to this [TransitionSpec] is
+     * started.
+     */
+    fun transformationSpec(): TransformationSpec
+}
+
+interface TransformationSpec {
+    /** The [AnimationSpec] used to animate the associated transition progress. */
+    val progressSpec: AnimationSpec<Float>
+
+    /** The list of [Transformation] applied to elements during this transition. */
+    val transformations: List<Transformation>
+
+    companion object {
+        internal val Empty =
+            TransformationSpecImpl(progressSpec = snap(), transformations = emptyList())
+        internal val EmptyProvider = { Empty }
+    }
+}
 
-    internal fun reverse(): TransitionSpec {
-        return copy(
+internal class TransitionSpecImpl(
+    override val from: SceneKey?,
+    override val to: SceneKey?,
+    private val transformationSpec: () -> TransformationSpecImpl,
+) : TransitionSpec {
+    override fun reversed(): TransitionSpecImpl {
+        return TransitionSpecImpl(
             from = to,
             to = from,
-            transformations = transformations.fastMap { it.reverse() },
+            transformationSpec = {
+                val reverse = transformationSpec.invoke()
+                TransformationSpecImpl(
+                    progressSpec = reverse.progressSpec,
+                    transformations = reverse.transformations.map { it.reversed() }
+                )
+            }
         )
     }
 
+    override fun transformationSpec(): TransformationSpecImpl = this.transformationSpec.invoke()
+}
+
+/**
+ * An implementation of [TransformationSpec] that allows the quick retrieval of an element
+ * [ElementTransformations].
+ */
+internal class TransformationSpecImpl(
+    override val progressSpec: AnimationSpec<Float>,
+    override val transformations: List<Transformation>,
+) : TransformationSpec {
+    private val cache = mutableMapOf<ElementKey, MutableMap<SceneKey, ElementTransformations>>()
+
     internal fun transformations(element: ElementKey, scene: SceneKey): ElementTransformations {
         return cache
             .getOrPut(element) { mutableMapOf() }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeToScene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeToScene.kt
index 116a66673d0afdcd011b9d5894564d14d4fb4842..0d3bc7d0cd85d73faf9f3d0e1e658e4b17feeedc 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeToScene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeToScene.kt
@@ -27,7 +27,8 @@ internal fun Modifier.swipeToScene(gestureHandler: SceneGestureHandler): Modifie
     fun Scene.shouldEnableSwipes(orientation: Orientation): Boolean =
         userActions.keys.any { it is Swipe && it.direction.orientation == orientation }
 
-    val currentScene = gestureHandler.currentScene
+    val layoutImpl = gestureHandler.layoutImpl
+    val currentScene = layoutImpl.scene(layoutImpl.state.transitionState.currentScene)
     val orientation = gestureHandler.orientation
     val canSwipe = currentScene.shouldEnableSwipes(orientation)
     val canOppositeSwipe =
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
index 8f4a36e47212b30eee935aafcb8e9e9e924213c9..70468669297c8f274fc33ce0147e82abd2b3678e 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
@@ -44,7 +44,7 @@ internal fun transitionsImpl(
 }
 
 private class SceneTransitionsBuilderImpl : SceneTransitionsBuilder {
-    val transitionSpecs = mutableListOf<TransitionSpec>()
+    val transitionSpecs = mutableListOf<TransitionSpecImpl>()
 
     override fun to(to: SceneKey, builder: TransitionBuilder.() -> Unit): TransitionSpec {
         return transition(from = null, to = to, builder)
@@ -63,14 +63,15 @@ private class SceneTransitionsBuilderImpl : SceneTransitionsBuilder {
         to: SceneKey?,
         builder: TransitionBuilder.() -> Unit,
     ): TransitionSpec {
-        val impl = TransitionBuilderImpl().apply(builder)
-        val spec =
-            TransitionSpec(
-                from,
-                to,
-                impl.transformations,
-                impl.spec,
+        fun transformationSpec(): TransformationSpecImpl {
+            val impl = TransitionBuilderImpl().apply(builder)
+            return TransformationSpecImpl(
+                progressSpec = impl.spec,
+                transformations = impl.transformations,
             )
+        }
+
+        val spec = TransitionSpecImpl(from, to, ::transformationSpec)
         transitionSpecs.add(spec)
         return spec
     }
@@ -143,7 +144,7 @@ internal class TransitionBuilderImpl : TransitionBuilder {
 
         transformations.add(
             if (reversed) {
-                transformation.reverse()
+                transformation.reversed()
             } else {
                 transformation
             }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transformation/Transformation.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transformation/Transformation.kt
index 206935558179a34657ce27c6e651fd51bd359afe..0cd11b9914c9760425bd459c4bf85e63b4481124 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transformation/Transformation.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transformation/Transformation.kt
@@ -42,7 +42,7 @@ sealed interface Transformation {
      * Reverse this transformation. This is called when we use Transition(from = A, to = B) when
      * animating from B to A and there is no Transition(from = B, to = A) defined.
      */
-    fun reverse(): Transformation = this
+    fun reversed(): Transformation = this
 }
 
 internal class SharedElementTransformation(
@@ -77,10 +77,10 @@ internal class RangedPropertyTransformation<T>(
     val delegate: PropertyTransformation<T>,
     override val range: TransformationRange,
 ) : PropertyTransformation<T> by delegate {
-    override fun reverse(): Transformation {
+    override fun reversed(): Transformation {
         return RangedPropertyTransformation(
-            delegate.reverse() as PropertyTransformation<T>,
-            range.reverse()
+            delegate.reversed() as PropertyTransformation<T>,
+            range.reversed()
         )
     }
 }
@@ -102,7 +102,7 @@ data class TransformationRange(
     }
 
     /** Reverse this range. */
-    fun reverse() = TransformationRange(start = reverseBound(end), end = reverseBound(start))
+    fun reversed() = TransformationRange(start = reverseBound(end), end = reverseBound(start))
 
     /** Get the progress of this range given the global [transitionProgress]. */
     fun progress(transitionProgress: Float): Float {
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneGestureHandlerTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneGestureHandlerTest.kt
index e6224df649cab2e0c792ce66343fe58dae5c55cc..d9ce5191f3d90641c0b5233edb251bf5eb01826e 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneGestureHandlerTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneGestureHandlerTest.kt
@@ -55,8 +55,8 @@ class SceneGestureHandlerTest {
     ) {
         private var internalCurrentScene: SceneKey by mutableStateOf(SceneA)
 
-        private val layoutState: SceneTransitionLayoutState =
-            SceneTransitionLayoutState(internalCurrentScene)
+        private val layoutState =
+            SceneTransitionLayoutStateImpl(internalCurrentScene, EmptyTestTransitions)
 
         val mutableUserActionsA: MutableMap<UserAction, SceneKey> =
             mutableMapOf(Swipe.Up to SceneB, Swipe.Down to SceneC)
@@ -93,36 +93,24 @@ class SceneGestureHandlerTest {
 
         private val layoutImpl =
             SceneTransitionLayoutImpl(
-                    onChangeScene = { internalCurrentScene = it },
-                    builder = scenesBuilder,
-                    transitions = EmptyTestTransitions,
                     state = layoutState,
+                    onChangeScene = { internalCurrentScene = it },
                     density = Density(1f),
                     edgeDetector = DefaultEdgeDetector,
                     transitionInterceptionThreshold = transitionInterceptionThreshold,
+                    builder = scenesBuilder,
                     coroutineScope = coroutineScope,
                 )
                 .apply { setScenesTargetSizeForTest(LAYOUT_SIZE) }
 
-        val sceneGestureHandler =
-            SceneGestureHandler(
-                layoutImpl = layoutImpl,
-                orientation = Orientation.Vertical,
-                coroutineScope = coroutineScope,
-            )
-
-        val horizontalSceneGestureHandler =
-            SceneGestureHandler(
-                layoutImpl = layoutImpl,
-                orientation = Orientation.Horizontal,
-                coroutineScope = coroutineScope,
-            )
-
+        val sceneGestureHandler = layoutImpl.gestureHandler(Orientation.Vertical)
+        val horizontalSceneGestureHandler = layoutImpl.gestureHandler(Orientation.Horizontal)
         val draggable = sceneGestureHandler.draggable
 
         fun nestedScrollConnection(nestedScrollBehavior: NestedScrollBehavior) =
             SceneNestedScrollHandler(
-                    gestureHandler = sceneGestureHandler,
+                    layoutImpl,
+                    orientation = sceneGestureHandler.orientation,
                     topOrLeftBehavior = nestedScrollBehavior,
                     bottomOrRightBehavior = nestedScrollBehavior,
                 )
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
index eeda8d46adfa30a06b1b293e1937b0c631e34961..c5b8d9ae0d10f5f28ae089657a25161adc2ca46d 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
@@ -40,8 +40,8 @@ class SceneTransitionLayoutStateTest {
 
     @Test
     fun isTransitioningTo_transition() {
-        val state = SceneTransitionLayoutState(TestScenes.SceneA)
-        state.transitionState = transition(from = TestScenes.SceneA, to = TestScenes.SceneB)
+        val state = SceneTransitionLayoutStateImpl(TestScenes.SceneA, SceneTransitions.Empty)
+        state.startTransition(transition(from = TestScenes.SceneA, to = TestScenes.SceneB))
 
         assertThat(state.isTransitioning()).isTrue()
         assertThat(state.isTransitioning(from = TestScenes.SceneA)).isTrue()
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
index fa94b25028a23e7b24e59930065fc7bfe93e654f..ef729921f4cdd4e746970e3a5ea21dbe49aec78c 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
@@ -55,7 +55,7 @@ class TransitionDslTest {
 
         assertThat(transitions.transitionSpecs)
             .comparingElementsUsing(
-                Correspondence.transforming<TransitionSpec, Pair<SceneKey?, SceneKey?>>(
+                Correspondence.transforming<TransitionSpecImpl, Pair<SceneKey?, SceneKey?>>(
                     { it?.from to it?.to },
                     "has (from, to) equal to"
                 )
@@ -70,8 +70,8 @@ class TransitionDslTest {
     @Test
     fun defaultTransitionSpec() {
         val transitions = transitions { from(TestScenes.SceneA, to = TestScenes.SceneB) }
-        val transition = transitions.transitionSpecs.single()
-        assertThat(transition.spec).isInstanceOf(SpringSpec::class.java)
+        val transformationSpec = transitions.transitionSpecs.single().transformationSpec()
+        assertThat(transformationSpec.progressSpec).isInstanceOf(SpringSpec::class.java)
     }
 
     @Test
@@ -79,9 +79,9 @@ class TransitionDslTest {
         val transitions = transitions {
             from(TestScenes.SceneA, to = TestScenes.SceneB) { spec = tween(durationMillis = 42) }
         }
-        val transition = transitions.transitionSpecs.single()
-        assertThat(transition.spec).isInstanceOf(TweenSpec::class.java)
-        assertThat((transition.spec as TweenSpec).durationMillis).isEqualTo(42)
+        val transformationSpec = transitions.transitionSpecs.single().transformationSpec()
+        assertThat(transformationSpec.progressSpec).isInstanceOf(TweenSpec::class.java)
+        assertThat((transformationSpec.progressSpec as TweenSpec).durationMillis).isEqualTo(42)
     }
 
     @Test
@@ -90,9 +90,10 @@ class TransitionDslTest {
             from(TestScenes.SceneA, to = TestScenes.SceneB) { fade(TestElements.Foo) }
         }
 
-        val transition = transitions.transitionSpecs.single()
-        assertThat(transition.transformations.size).isEqualTo(1)
-        assertThat(transition.transformations.single().range).isEqualTo(null)
+        val transformations =
+            transitions.transitionSpecs.single().transformationSpec().transformations
+        assertThat(transformations.size).isEqualTo(1)
+        assertThat(transformations.single().range).isEqualTo(null)
     }
 
     @Test
@@ -105,8 +106,9 @@ class TransitionDslTest {
             }
         }
 
-        val transition = transitions.transitionSpecs.single()
-        assertThat(transition.transformations)
+        val transformations =
+            transitions.transitionSpecs.single().transformationSpec().transformations
+        assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
             .containsExactly(
                 TransformationRange(start = 0.1f, end = 0.8f),
@@ -127,8 +129,9 @@ class TransitionDslTest {
             }
         }
 
-        val transition = transitions.transitionSpecs.single()
-        assertThat(transition.transformations)
+        val transformations =
+            transitions.transitionSpecs.single().transformationSpec().transformations
+        assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
             .containsExactly(
                 TransformationRange(start = 100 / 500f, end = 300 / 500f),
@@ -149,8 +152,9 @@ class TransitionDslTest {
             }
         }
 
-        val transition = transitions.transitionSpecs.single()
-        assertThat(transition.transformations)
+        val transformations =
+            transitions.transitionSpecs.single().transformationSpec().transformations
+        assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
             .containsExactly(
                 TransformationRange(start = 1f - 0.8f, end = 1f - 0.1f),
@@ -170,9 +174,13 @@ class TransitionDslTest {
 
         // Fetch the transition from B to A, which will automatically reverse the transition from A
         // to B we defined.
-        val transition =
-            transitions.transitionSpec(from = TestScenes.SceneB, to = TestScenes.SceneA)
-        assertThat(transition.transformations)
+        val transformations =
+            transitions
+                .transitionSpec(from = TestScenes.SceneB, to = TestScenes.SceneA)
+                .transformationSpec()
+                .transformations
+
+        assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
             .containsExactly(
                 TransformationRange(start = 1f - 0.8f, end = 1f - 0.1f),