diff --git a/services/core/java/com/android/server/vcn/VcnContext.java b/services/core/java/com/android/server/vcn/VcnContext.java index d958222ea407040603c2ac90255f04e494dbd5f6..9213d96ad4ca60229a37961abe286c064614ba0f 100644 --- a/services/core/java/com/android/server/vcn/VcnContext.java +++ b/services/core/java/com/android/server/vcn/VcnContext.java @@ -18,6 +18,8 @@ package com.android.server.vcn; import android.annotation.NonNull; import android.content.Context; +import android.net.vcn.FeatureFlags; +import android.net.vcn.FeatureFlagsImpl; import android.os.Looper; import java.util.Objects; @@ -31,6 +33,7 @@ public class VcnContext { @NonNull private final Context mContext; @NonNull private final Looper mLooper; @NonNull private final VcnNetworkProvider mVcnNetworkProvider; + @NonNull private final FeatureFlags mFeatureFlags; private final boolean mIsInTestMode; public VcnContext( @@ -42,6 +45,9 @@ public class VcnContext { mLooper = Objects.requireNonNull(looper, "Missing looper"); mVcnNetworkProvider = Objects.requireNonNull(vcnNetworkProvider, "Missing networkProvider"); mIsInTestMode = isInTestMode; + + // Auto-generated class + mFeatureFlags = new FeatureFlagsImpl(); } @NonNull @@ -63,6 +69,11 @@ public class VcnContext { return mIsInTestMode; } + @NonNull + public FeatureFlags getFeatureFlags() { + return mFeatureFlags; + } + /** * Verifies that the caller is running on the VcnContext Thread. * diff --git a/services/core/java/com/android/server/vcn/VcnGatewayConnection.java b/services/core/java/com/android/server/vcn/VcnGatewayConnection.java index d480ddb092eb3544642886ece1388583d0e6eaaf..54c97dd37941740cbcb36cd986091b1cb4638cc7 100644 --- a/services/core/java/com/android/server/vcn/VcnGatewayConnection.java +++ b/services/core/java/com/android/server/vcn/VcnGatewayConnection.java @@ -1222,6 +1222,14 @@ public class VcnGatewayConnection extends StateMachine { @VisibleForTesting(visibility = Visibility.PRIVATE) void setSafeModeAlarm() { + final boolean isFlagSafeModeConfigEnabled = mVcnContext.getFeatureFlags().safeModeConfig(); + logVdbg("isFlagSafeModeConfigEnabled " + isFlagSafeModeConfigEnabled); + + if (isFlagSafeModeConfigEnabled && !mConnectionConfig.isSafeModeEnabled()) { + logVdbg("setSafeModeAlarm: safe mode disabled"); + return; + } + logVdbg("Setting safe mode alarm; mCurrentToken: " + mCurrentToken); // Only schedule a NEW alarm if none is already set. diff --git a/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java b/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java index 359ef83cfe7ca013f9030bd8d21a8c9975532342..cb3782173dc86bd626c29f1eb45f916151f0f4f5 100644 --- a/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java +++ b/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java @@ -117,6 +117,16 @@ public class VcnGatewayConnectionConfigTest { return buildTestConfig(UNDERLYING_NETWORK_TEMPLATES); } + // Public for use in VcnGatewayConnectionTest + public static VcnGatewayConnectionConfig.Builder newTestBuilderMinimal() { + final VcnGatewayConnectionConfig.Builder builder = newBuilder(); + for (int caps : EXPOSED_CAPS) { + builder.addExposedCapability(caps); + } + + return builder; + } + private static VcnGatewayConnectionConfig.Builder newBuilder() { // Append a unique identifier to the name prefix to guarantee that all created // VcnGatewayConnectionConfigs have a unique name (required by VcnConfig). diff --git a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java index 302af523a4bdb0e9827781ef3cac68358b4fd464..bf73198d10068b7b7c2aeb884edc28a7c87d3d50 100644 --- a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java +++ b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionConnectedStateTest.java @@ -75,6 +75,9 @@ import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; import com.android.server.vcn.VcnGatewayConnection.VcnChildSessionCallback; +import com.android.server.vcn.VcnGatewayConnection.VcnChildSessionConfiguration; +import com.android.server.vcn.VcnGatewayConnection.VcnIkeSession; +import com.android.server.vcn.VcnGatewayConnection.VcnNetworkAgent; import com.android.server.vcn.routeselection.UnderlyingNetworkRecord; import com.android.server.vcn.util.MtuUtils; @@ -651,6 +654,74 @@ public class VcnGatewayConnectionConnectedStateTest extends VcnGatewayConnection verifySafeModeStateAndCallbackFired(2 /* invocationCount */, true /* isInSafeMode */); } + private void verifySetSafeModeAlarm( + boolean safeModeEnabledByCaller, + boolean safeModeConfigFlagEnabled, + boolean expectingSafeModeEnabled) + throws Exception { + final VcnGatewayConnectionConfig config = + VcnGatewayConnectionConfigTest.newTestBuilderMinimal() + .enableSafeMode(safeModeEnabledByCaller) + .build(); + final VcnGatewayConnection.Dependencies deps = + mock(VcnGatewayConnection.Dependencies.class); + setUpWakeupMessage( + mSafeModeTimeoutAlarm, VcnGatewayConnection.SAFEMODE_TIMEOUT_ALARM, deps); + doReturn(safeModeConfigFlagEnabled).when(mFeatureFlags).safeModeConfig(); + + final VcnGatewayConnection connection = + new VcnGatewayConnection( + mVcnContext, + TEST_SUB_GRP, + TEST_SUBSCRIPTION_SNAPSHOT, + config, + mGatewayStatusCallback, + true /* isMobileDataEnabled */, + deps); + + connection.setSafeModeAlarm(); + + final int expectedCallCnt = expectingSafeModeEnabled ? 1 : 0; + verify(deps, times(expectedCallCnt)) + .newWakeupMessage( + eq(mVcnContext), + any(), + eq(VcnGatewayConnection.SAFEMODE_TIMEOUT_ALARM), + any()); + } + + @Test + public void testSafeModeEnabled_configFlagEnabled() throws Exception { + verifySetSafeModeAlarm( + true /* safeModeEnabledByCaller */, + true /* safeModeConfigFlagEnabled */, + true /* expectingSafeModeEnabled */); + } + + @Test + public void testSafeModeEnabled_configFlagDisabled() throws Exception { + verifySetSafeModeAlarm( + true /* safeModeEnabledByCaller */, + false /* safeModeConfigFlagEnabled */, + true /* expectingSafeModeEnabled */); + } + + @Test + public void testSafeModeDisabled_configFlagEnabled() throws Exception { + verifySetSafeModeAlarm( + false /* safeModeEnabledByCaller */, + true /* safeModeConfigFlagEnabled */, + false /* expectingSafeModeEnabled */); + } + + @Test + public void testSafeModeDisabled_configFlagDisabled() throws Exception { + verifySetSafeModeAlarm( + false /* safeModeEnabledByCaller */, + false /* safeModeConfigFlagEnabled */, + true /* expectingSafeModeEnabled */); + } + private Consumer<VcnNetworkAgent> setupNetworkAndGetUnwantedCallback() { triggerChildOpened(); mTestLooper.dispatchAll(); diff --git a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java index 5efbf598f941130dae4a21ac13a2277b8480c5f7..edced87427c8ee0f7c6b3f5a96fa3b91741076f7 100644 --- a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java +++ b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java @@ -53,6 +53,7 @@ import android.net.ipsec.ike.ChildSessionCallback; import android.net.ipsec.ike.IkeSessionCallback; import android.net.ipsec.ike.IkeSessionConfiguration; import android.net.ipsec.ike.IkeSessionConnectionInfo; +import android.net.vcn.FeatureFlags; import android.net.vcn.VcnGatewayConnectionConfig; import android.net.vcn.VcnGatewayConnectionConfigTest; import android.os.ParcelUuid; @@ -165,6 +166,7 @@ public class VcnGatewayConnectionTestBase { @NonNull protected final Context mContext; @NonNull protected final TestLooper mTestLooper; @NonNull protected final VcnNetworkProvider mVcnNetworkProvider; + @NonNull protected final FeatureFlags mFeatureFlags; @NonNull protected final VcnContext mVcnContext; @NonNull protected final VcnGatewayConnectionConfig mConfig; @NonNull protected final VcnGatewayStatusCallback mGatewayStatusCallback; @@ -190,6 +192,7 @@ public class VcnGatewayConnectionTestBase { mContext = mock(Context.class); mTestLooper = new TestLooper(); mVcnNetworkProvider = mock(VcnNetworkProvider.class); + mFeatureFlags = mock(FeatureFlags.class); mVcnContext = mock(VcnContext.class); mConfig = VcnGatewayConnectionConfigTest.buildTestConfig(); mGatewayStatusCallback = mock(VcnGatewayStatusCallback.class); @@ -222,6 +225,7 @@ public class VcnGatewayConnectionTestBase { doReturn(mContext).when(mVcnContext).getContext(); doReturn(mTestLooper.getLooper()).when(mVcnContext).getLooper(); doReturn(mVcnNetworkProvider).when(mVcnContext).getVcnNetworkProvider(); + doReturn(mFeatureFlags).when(mVcnContext).getFeatureFlags(); doReturn(mUnderlyingNetworkController) .when(mDeps) @@ -241,8 +245,15 @@ public class VcnGatewayConnectionTestBase { doReturn(ELAPSED_REAL_TIME).when(mDeps).getElapsedRealTime(); } + protected void setUpWakeupMessage( + @NonNull WakeupMessage msg, + @NonNull String cmdName, + VcnGatewayConnection.Dependencies deps) { + doReturn(msg).when(deps).newWakeupMessage(eq(mVcnContext), any(), eq(cmdName), any()); + } + private void setUpWakeupMessage(@NonNull WakeupMessage msg, @NonNull String cmdName) { - doReturn(msg).when(mDeps).newWakeupMessage(eq(mVcnContext), any(), eq(cmdName), any()); + setUpWakeupMessage(msg, cmdName, mDeps); } @Before