diff --git a/core/api/current.txt b/core/api/current.txt
index d48685ba6f66284bb58fcfbb7b22977ebff190b4..2cafc5b8255027c8a591a6e3a014c01b9613d0da 100644
--- a/core/api/current.txt
+++ b/core/api/current.txt
@@ -15222,6 +15222,7 @@ package android.graphics {
     method public int getMaxAnisotropy();
     method public void setFilterMode(int);
     method public void setMaxAnisotropy(@IntRange(from=1) int);
+    method @FlaggedApi("com.android.graphics.hwui.flags.gainmap_animations") public void setOverrideGainmap(@Nullable android.graphics.Gainmap);
     field public static final int FILTER_MODE_DEFAULT = 0; // 0x0
     field public static final int FILTER_MODE_LINEAR = 2; // 0x2
     field public static final int FILTER_MODE_NEAREST = 1; // 0x1
diff --git a/graphics/java/android/graphics/BitmapShader.java b/graphics/java/android/graphics/BitmapShader.java
index 5c065775eea23b396d7303aec75542ed84f2ac94..dcfff62459abc318b42316351b66b506a6988e6b 100644
--- a/graphics/java/android/graphics/BitmapShader.java
+++ b/graphics/java/android/graphics/BitmapShader.java
@@ -16,9 +16,13 @@
 
 package android.graphics;
 
+import android.annotation.FlaggedApi;
 import android.annotation.IntDef;
 import android.annotation.IntRange;
 import android.annotation.NonNull;
+import android.annotation.Nullable;
+
+import com.android.graphics.hwui.flags.Flags;
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
@@ -32,6 +36,7 @@ public class BitmapShader extends Shader {
      * Prevent garbage collection.
      */
     /*package*/ Bitmap mBitmap;
+    private Gainmap mOverrideGainmap;
 
     private int mTileX;
     private int mTileY;
@@ -172,6 +177,24 @@ public class BitmapShader extends Shader {
         }
     }
 
+    /**
+     * Draws the BitmapShader with a copy of the given gainmap instead of the gainmap on the Bitmap
+     * the shader was constructed from
+     *
+     * @param overrideGainmap The gainmap to draw instead, null to use any gainmap on the Bitmap
+     */
+    @FlaggedApi(Flags.FLAG_GAINMAP_ANIMATIONS)
+    public void setOverrideGainmap(@Nullable Gainmap overrideGainmap) {
+        if (!Flags.gainmapAnimations()) throw new IllegalStateException("API not available");
+
+        if (overrideGainmap == null) {
+            mOverrideGainmap = null;
+        } else {
+            mOverrideGainmap = new Gainmap(overrideGainmap, overrideGainmap.getGainmapContents());
+        }
+        discardNativeInstance();
+    }
+
     /**
      * Returns the current max anisotropic filtering value configured by
      * {@link #setFilterMode(int)}. If {@link #setFilterMode(int)} is invoked this returns zero.
@@ -199,14 +222,9 @@ public class BitmapShader extends Shader {
 
         mIsDirectSampled = mRequestDirectSampling;
         mRequestDirectSampling = false;
-
-        if (mMaxAniso > 0) {
-            return nativeCreateWithMaxAniso(nativeMatrix, mBitmap.getNativeInstance(), mTileX,
-                    mTileY, mMaxAniso, mIsDirectSampled);
-        } else {
-            return nativeCreate(nativeMatrix, mBitmap.getNativeInstance(), mTileX, mTileY,
-                    enableLinearFilter, mIsDirectSampled);
-        }
+        return nativeCreate(nativeMatrix, mBitmap.getNativeInstance(), mTileX,
+                mTileY, mMaxAniso, enableLinearFilter, mIsDirectSampled,
+                mOverrideGainmap != null ? mOverrideGainmap.mNativePtr : 0);
     }
 
     /** @hide */
@@ -217,9 +235,7 @@ public class BitmapShader extends Shader {
     }
 
     private static native long nativeCreate(long nativeMatrix, long bitmapHandle,
-            int shaderTileModeX, int shaderTileModeY, boolean filter, boolean isDirectSampled);
-
-    private static native long nativeCreateWithMaxAniso(long nativeMatrix, long bitmapHandle,
-            int shaderTileModeX, int shaderTileModeY, int maxAniso, boolean isDirectSampled);
+            int shaderTileModeX, int shaderTileModeY, int maxAniso, boolean filter,
+            boolean isDirectSampled, long overrideGainmapHandle);
 }
 
diff --git a/libs/hwui/jni/Shader.cpp b/libs/hwui/jni/Shader.cpp
index 2c13ceb77b5204e83f052cc3fc905f9748a2bc3a..a952be020855384e7fcb331845cc688cca479394 100644
--- a/libs/hwui/jni/Shader.cpp
+++ b/libs/hwui/jni/Shader.cpp
@@ -65,21 +65,41 @@ static jlong Shader_getNativeFinalizer(JNIEnv*, jobject) {
     return static_cast<jlong>(reinterpret_cast<uintptr_t>(&Shader_safeUnref));
 }
 
-static jlong createBitmapShaderHelper(JNIEnv* env, jobject o, jlong matrixPtr, jlong bitmapHandle,
-                                      jint tileModeX, jint tileModeY, bool isDirectSampled,
-                                      const SkSamplingOptions& sampling) {
+///////////////////////////////////////////////////////////////////////////////////////////////
+
+static SkGainmapInfo sNoOpGainmap = {
+        .fGainmapRatioMin = {1.f, 1.f, 1.f, 1.0},
+        .fGainmapRatioMax = {1.f, 1.f, 1.f, 1.0},
+        .fGainmapGamma = {1.f, 1.f, 1.f, 1.f},
+        .fEpsilonSdr = {0.f, 0.f, 0.f, 1.0},
+        .fEpsilonHdr = {0.f, 0.f, 0.f, 1.0},
+        .fDisplayRatioSdr = 1.f,
+        .fDisplayRatioHdr = 1.f,
+};
+
+static jlong BitmapShader_constructor(JNIEnv* env, jobject o, jlong matrixPtr, jlong bitmapHandle,
+                                      jint tileModeX, jint tileModeY, jint maxAniso, bool filter,
+                                      bool isDirectSampled, jlong overrideGainmapPtr) {
+    SkSamplingOptions sampling = maxAniso > 0 ? SkSamplingOptions::Aniso(static_cast<int>(maxAniso))
+                                              : SkSamplingOptions(filter ? SkFilterMode::kLinear
+                                                                         : SkFilterMode::kNearest,
+                                                                  SkMipmapMode::kNone);
     const SkMatrix* matrix = reinterpret_cast<const SkMatrix*>(matrixPtr);
+    const Gainmap* gainmap = reinterpret_cast<Gainmap*>(overrideGainmapPtr);
     sk_sp<SkImage> image;
     if (bitmapHandle) {
         // Only pass a valid SkBitmap object to the constructor if the Bitmap exists. Otherwise,
         // we'll pass an empty SkBitmap to avoid crashing/excepting for compatibility.
         auto& bitmap = android::bitmap::toBitmap(bitmapHandle);
         image = bitmap.makeImage();
+        if (!gainmap && bitmap.hasGainmap()) {
+            gainmap = bitmap.gainmap().get();
+        }
 
-        if (!isDirectSampled && bitmap.hasGainmap()) {
-            sk_sp<SkShader> gainmapShader = MakeGainmapShader(
-                    image, bitmap.gainmap()->bitmap->makeImage(), bitmap.gainmap()->info,
-                    (SkTileMode)tileModeX, (SkTileMode)tileModeY, sampling);
+        if (!isDirectSampled && gainmap && gainmap->info != sNoOpGainmap) {
+            sk_sp<SkShader> gainmapShader =
+                    MakeGainmapShader(image, gainmap->bitmap->makeImage(), gainmap->info,
+                                      (SkTileMode)tileModeX, (SkTileMode)tileModeY, sampling);
             if (gainmapShader) {
                 if (matrix) {
                     gainmapShader = gainmapShader->makeWithLocalMatrix(*matrix);
@@ -111,26 +131,6 @@ static jlong createBitmapShaderHelper(JNIEnv* env, jobject o, jlong matrixPtr, j
 
 ///////////////////////////////////////////////////////////////////////////////////////////////
 
-static jlong BitmapShader_constructor(JNIEnv* env, jobject o, jlong matrixPtr, jlong bitmapHandle,
-                                      jint tileModeX, jint tileModeY, bool filter,
-                                      bool isDirectSampled) {
-    SkSamplingOptions sampling(filter ? SkFilterMode::kLinear : SkFilterMode::kNearest,
-                               SkMipmapMode::kNone);
-    return createBitmapShaderHelper(env, o, matrixPtr, bitmapHandle, tileModeX, tileModeY,
-                                    isDirectSampled, sampling);
-}
-
-static jlong BitmapShader_constructorWithMaxAniso(JNIEnv* env, jobject o, jlong matrixPtr,
-                                                  jlong bitmapHandle, jint tileModeX,
-                                                  jint tileModeY, jint maxAniso,
-                                                  bool isDirectSampled) {
-    auto sampling = SkSamplingOptions::Aniso(static_cast<int>(maxAniso));
-    return createBitmapShaderHelper(env, o, matrixPtr, bitmapHandle, tileModeX, tileModeY,
-                                    isDirectSampled, sampling);
-}
-
-///////////////////////////////////////////////////////////////////////////////////////////////
-
 static std::vector<SkColor4f> convertColorLongs(JNIEnv* env, jlongArray colorArray) {
     const size_t count = env->GetArrayLength(colorArray);
     const jlong* colorValues = env->GetLongArrayElements(colorArray, nullptr);
@@ -419,8 +419,7 @@ static const JNINativeMethod gShaderMethods[] = {
 };
 
 static const JNINativeMethod gBitmapShaderMethods[] = {
-        {"nativeCreate", "(JJIIZZ)J", (void*)BitmapShader_constructor},
-        {"nativeCreateWithMaxAniso", "(JJIIIZ)J", (void*)BitmapShader_constructorWithMaxAniso},
+        {"nativeCreate", "(JJIIIZZJ)J", (void*)BitmapShader_constructor},
 
 };