diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java
index 8fadf9e295abb21db3b23163d0b3833ce2916c6d..5c99c679c88c7956a29f1ec24ad3b394e4ed8c26 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java
@@ -15,21 +15,20 @@
  */
 package com.android.cts.net.hostside;
 
-import static com.android.cts.net.hostside.NetworkPolicyTestUtils.resetMeteredNetwork;
-import static com.android.cts.net.hostside.NetworkPolicyTestUtils.setupMeteredNetwork;
+import static com.android.cts.net.hostside.NetworkPolicyTestUtils.setupActiveNetworkMeteredness;
 import static com.android.cts.net.hostside.Property.METERED_NETWORK;
 import static com.android.cts.net.hostside.Property.NON_METERED_NETWORK;
 
 import android.util.ArraySet;
-import android.util.Pair;
 
 import com.android.compatibility.common.util.BeforeAfterRule;
+import com.android.compatibility.common.util.ThrowingRunnable;
 
 import org.junit.runner.Description;
 import org.junit.runners.model.Statement;
 
 public class MeterednessConfigurationRule extends BeforeAfterRule {
-    private Pair<String, Boolean> mSsidAndInitialMeteredness;
+    private ThrowingRunnable mMeterednessResetter;
 
     @Override
     public void onBefore(Statement base, Description description) throws Throwable {
@@ -48,13 +47,13 @@ public class MeterednessConfigurationRule extends BeforeAfterRule {
     }
 
     public void configureNetworkMeteredness(boolean metered) throws Exception {
-        mSsidAndInitialMeteredness = setupMeteredNetwork(metered);
+        mMeterednessResetter = setupActiveNetworkMeteredness(metered);
     }
 
     public void resetNetworkMeteredness() throws Exception {
-        if (mSsidAndInitialMeteredness != null) {
-            resetMeteredNetwork(mSsidAndInitialMeteredness.first,
-                    mSsidAndInitialMeteredness.second);
+        if (mMeterednessResetter != null) {
+            mMeterednessResetter.run();
+            mMeterednessResetter = null;
         }
     }
 }
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
index 2ac29e77ff84a340c794084521bc9b1a707bb495..955317bbf681be319c7e512860678909ffc24438 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
@@ -17,16 +17,13 @@
 package com.android.cts.net.hostside;
 
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
+
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.canChangeActiveNetworkMeteredness;
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.setRestrictBackground;
-import static com.android.cts.net.hostside.NetworkPolicyTestUtils.isActiveNetworkMetered;
 import static com.android.cts.net.hostside.Property.BATTERY_SAVER_MODE;
 import static com.android.cts.net.hostside.Property.DATA_SAVER_MODE;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.junit.Assume.assumeTrue;
 
@@ -186,7 +183,7 @@ public class NetworkCallbackTest extends AbstractRestrictBackgroundNetworkTestCa
     public void setUp() throws Exception {
         super.setUp();
 
-        assumeTrue(isActiveNetworkMetered(true) || canChangeActiveNetworkMeteredness());
+        assumeTrue(canChangeActiveNetworkMeteredness());
 
         registerBroadcastReceiver();
 
@@ -198,13 +195,13 @@ public class NetworkCallbackTest extends AbstractRestrictBackgroundNetworkTestCa
         setBatterySaverMode(false);
         setRestrictBackground(false);
 
-        // Make wifi a metered network.
+        // Mark network as metered.
         mMeterednessConfiguration.configureNetworkMeteredness(true);
 
         // Register callback
         registerNetworkCallback((INetworkCallback.Stub) mTestNetworkCallback);
-        // Once the wifi is marked as metered, the wifi will reconnect. Wait for onAvailable()
-        // callback to ensure wifi is connected before the test and store the default network.
+        // Wait for onAvailable() callback to ensure network is available before the test
+        // and store the default network.
         mNetwork = mTestNetworkCallback.expectAvailableCallbackAndGetNetwork();
         // Check that the network is metered.
         mTestNetworkCallback.expectCapabilitiesCallbackEventually(mNetwork,
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
index 3041dfa76b4d6ef4a1318cddcd3f1006c4b9cedc..e05fbea47cdcb87972893dd7c9bc867115217432 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
@@ -20,14 +20,17 @@ import static android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_DISABLE
 import static android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_ENABLED;
 import static android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_WHITELISTED;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
 import static com.android.cts.net.hostside.AbstractRestrictBackgroundNetworkTestCase.TAG;
+import static com.android.cts.net.hostside.AbstractRestrictBackgroundNetworkTestCase.TEST_PKG;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -41,24 +44,30 @@ import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.wifi.WifiManager;
 import android.os.Process;
+import android.telephony.SubscriptionManager;
+import android.telephony.SubscriptionPlan;
 import android.text.TextUtils;
 import android.util.Log;
-import android.util.Pair;
+
+import androidx.test.platform.app.InstrumentationRegistry;
 
 import com.android.compatibility.common.util.AppStandbyUtils;
 import com.android.compatibility.common.util.BatteryUtils;
+import com.android.compatibility.common.util.ThrowingRunnable;
 
+import java.time.Period;
+import java.time.ZonedDateTime;
+import java.util.Arrays;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
-import androidx.test.platform.app.InstrumentationRegistry;
-
 public class NetworkPolicyTestUtils {
 
     private static final int TIMEOUT_CHANGE_METEREDNESS_MS = 10_000;
 
     private static ConnectivityManager mCm;
     private static WifiManager mWm;
+    private static SubscriptionManager mSm;
 
     private static Boolean mBatterySaverSupported;
     private static Boolean mDataSaverSupported;
@@ -135,16 +144,40 @@ public class NetworkPolicyTestUtils {
     }
 
     public static boolean canChangeActiveNetworkMeteredness() {
-        final Network activeNetwork = getConnectivityManager().getActiveNetwork();
-        final NetworkCapabilities networkCapabilities
-                = getConnectivityManager().getNetworkCapabilities(activeNetwork);
-        return networkCapabilities.hasTransport(TRANSPORT_WIFI);
+        final NetworkCapabilities networkCapabilities = getActiveNetworkCapabilities();
+        return networkCapabilities.hasTransport(TRANSPORT_WIFI)
+                || networkCapabilities.hasTransport(TRANSPORT_CELLULAR);
     }
 
-    public static Pair<String, Boolean> setupMeteredNetwork(boolean metered) throws Exception {
+    /**
+     * Updates the meteredness of the active network. Right now we can only change meteredness
+     * of either Wifi or cellular network, so if the active network is not either of these, this
+     * will throw an exception.
+     *
+     * @return a {@link ThrowingRunnable} object that can used to reset the meteredness change
+     *         made by this method.
+     */
+    public static ThrowingRunnable setupActiveNetworkMeteredness(boolean metered) throws Exception {
         if (isActiveNetworkMetered(metered)) {
             return null;
         }
+        final NetworkCapabilities networkCapabilities = getActiveNetworkCapabilities();
+        if (networkCapabilities.hasTransport(TRANSPORT_WIFI)) {
+            final String ssid = getWifiSsid();
+            setWifiMeteredStatus(ssid, metered);
+            return () -> setWifiMeteredStatus(ssid, !metered);
+        } else if (networkCapabilities.hasTransport(TRANSPORT_CELLULAR)) {
+            final int subId = SubscriptionManager.getActiveDataSubscriptionId();
+            setCellularMeteredStatus(subId, metered);
+            return () -> setCellularMeteredStatus(subId, !metered);
+        } else {
+            // Right now, we don't have a way to change meteredness of networks other
+            // than Wi-Fi or Cellular, so just throw an exception.
+            throw new IllegalStateException("Can't change meteredness of current active network");
+        }
+    }
+
+    private static String getWifiSsid() {
         final boolean isLocationEnabled = isLocationEnabled();
         try {
             if (!isLocationEnabled) {
@@ -152,8 +185,7 @@ public class NetworkPolicyTestUtils {
             }
             final String ssid = unquoteSSID(getWifiManager().getConnectionInfo().getSSID());
             assertNotEquals(WifiManager.UNKNOWN_SSID, ssid);
-            setWifiMeteredStatus(ssid, metered);
-            return Pair.create(ssid, !metered);
+            return ssid;
         } finally {
             // Reset the location enabled state
             if (!isLocationEnabled) {
@@ -162,11 +194,13 @@ public class NetworkPolicyTestUtils {
         }
     }
 
-    public static void resetMeteredNetwork(String ssid, boolean metered) throws Exception {
-        setWifiMeteredStatus(ssid, metered);
+    private static NetworkCapabilities getActiveNetworkCapabilities() {
+        final Network activeNetwork = getConnectivityManager().getActiveNetwork();
+        assertNotNull("No active network available", activeNetwork);
+        return getConnectivityManager().getNetworkCapabilities(activeNetwork);
     }
 
-    public static void setWifiMeteredStatus(String ssid, boolean metered) throws Exception {
+    private static void setWifiMeteredStatus(String ssid, boolean metered) throws Exception {
         assertFalse("SSID should not be empty", TextUtils.isEmpty(ssid));
         final String cmd = "cmd netpolicy set metered-network " + ssid + " " + metered;
         executeShellCommand(cmd);
@@ -174,15 +208,29 @@ public class NetworkPolicyTestUtils {
         assertActiveNetworkMetered(metered);
     }
 
-    public static void assertWifiMeteredStatus(String ssid, boolean expectedMeteredStatus) {
+    private static void assertWifiMeteredStatus(String ssid, boolean expectedMeteredStatus) {
         final String result = executeShellCommand("cmd netpolicy list wifi-networks");
         final String expectedLine = ssid + ";" + expectedMeteredStatus;
         assertTrue("Expected line: " + expectedLine + "; Actual result: " + result,
                 result.contains(expectedLine));
     }
 
+    private static void setCellularMeteredStatus(int subId, boolean metered) throws Exception {
+        setSubPlanOwner(subId, TEST_PKG);
+        try {
+            getSubscriptionManager().setSubscriptionPlans(subId,
+                    Arrays.asList(buildValidSubscriptionPlan(System.currentTimeMillis())));
+            final boolean unmeteredOverride = !metered;
+            getSubscriptionManager().setSubscriptionOverrideUnmetered(subId, unmeteredOverride,
+                    /*timeoutMillis=*/ 0);
+            assertActiveNetworkMetered(metered);
+        } finally {
+            setSubPlanOwner(subId, null);
+        }
+    }
+
     // Copied from cts/tests/tests/net/src/android/net/cts/ConnectivityManagerTest.java
-    public static void assertActiveNetworkMetered(boolean expectedMeteredStatus) throws Exception {
+    private static void assertActiveNetworkMetered(boolean expectedMeteredStatus) throws Exception {
         final CountDownLatch latch = new CountDownLatch(1);
         final NetworkCallback networkCallback = new NetworkCallback() {
             @Override
@@ -197,12 +245,29 @@ public class NetworkPolicyTestUtils {
         // with the current setting. Therefore, if the setting has already been changed,
         // this method will return right away, and if not it will wait for the setting to change.
         getConnectivityManager().registerDefaultNetworkCallback(networkCallback);
-        if (!latch.await(TIMEOUT_CHANGE_METEREDNESS_MS, TimeUnit.MILLISECONDS)) {
-            fail("Timed out waiting for active network metered status to change to "
-                    + expectedMeteredStatus + " ; network = "
-                    + getConnectivityManager().getActiveNetwork());
+        try {
+            if (!latch.await(TIMEOUT_CHANGE_METEREDNESS_MS, TimeUnit.MILLISECONDS)) {
+                fail("Timed out waiting for active network metered status to change to "
+                        + expectedMeteredStatus + "; network = "
+                        + getConnectivityManager().getActiveNetwork());
+            }
+        } finally {
+            getConnectivityManager().unregisterNetworkCallback(networkCallback);
         }
-        getConnectivityManager().unregisterNetworkCallback(networkCallback);
+    }
+
+    private static void setSubPlanOwner(int subId, String packageName) {
+        executeShellCommand("cmd netpolicy set sub-plan-owner " + subId + " " + packageName);
+    }
+
+    private static SubscriptionPlan buildValidSubscriptionPlan(long dataUsageTime) {
+        return SubscriptionPlan.Builder
+                .createRecurring(ZonedDateTime.parse("2007-03-14T00:00:00.000Z"),
+                        Period.ofMonths(1))
+                .setTitle("CTS")
+                .setDataLimit(1_000_000_000, SubscriptionPlan.LIMIT_BEHAVIOR_DISABLED)
+                .setDataUsage(500_000_000, dataUsageTime)
+                .build();
     }
 
     public static void setRestrictBackground(boolean enabled) {
@@ -274,6 +339,14 @@ public class NetworkPolicyTestUtils {
         return mWm;
     }
 
+    public static SubscriptionManager getSubscriptionManager() {
+        if (mSm == null) {
+            mSm = (SubscriptionManager) getContext().getSystemService(
+                    Context.TELEPHONY_SUBSCRIPTION_SERVICE);
+        }
+        return mSm;
+    }
+
     public static Context getContext() {
         return getInstrumentation().getContext();
     }
diff --git a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
index 590e17e5e5da0974c6076dbf3ccbe3d64b1213a9..1c9ff05a5e2a1805067303453fb777bbfe220683 100644
--- a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
+++ b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
@@ -138,7 +138,7 @@ public class MyService extends Service {
                     }
                 }
             };
-            mCm.registerNetworkCallback(makeWifiNetworkRequest(), mNetworkCallback);
+            mCm.registerNetworkCallback(makeNetworkRequest(), mNetworkCallback);
             try {
                 cb.asBinder().linkToDeath(() -> unregisterNetworkCallback(), 0);
             } catch (RemoteException e) {
@@ -156,9 +156,8 @@ public class MyService extends Service {
         }
       };
 
-    private NetworkRequest makeWifiNetworkRequest() {
+    private NetworkRequest makeNetworkRequest() {
         return new NetworkRequest.Builder()
-                .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
                 .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                 .build();
     }