diff --git a/tests/unit/java/android/net/ConnectivityManagerTest.java b/tests/unit/java/android/net/ConnectivityManagerTest.java
index c804e1003887288083aaad6ec336b6118fdda9da..b8cd3f68d0db3f88e924a7540c64c7215339fe32 100644
--- a/tests/unit/java/android/net/ConnectivityManagerTest.java
+++ b/tests/unit/java/android/net/ConnectivityManagerTest.java
@@ -44,6 +44,7 @@ import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.nullable;
+import static org.mockito.Mockito.CALLS_REAL_METHODS;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.mock;
@@ -215,7 +216,8 @@ public class ConnectivityManagerTest {
     public void testCallbackRelease() throws Exception {
         ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
         NetworkRequest request = makeRequest(1);
-        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class);
+        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class,
+                CALLS_REAL_METHODS);
         Handler handler = new Handler(Looper.getMainLooper());
         ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
 
@@ -243,7 +245,8 @@ public class ConnectivityManagerTest {
         ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
         NetworkRequest req1 = makeRequest(1);
         NetworkRequest req2 = makeRequest(2);
-        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class);
+        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class,
+                CALLS_REAL_METHODS);
         Handler handler = new Handler(Looper.getMainLooper());
         ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
 
diff --git a/tests/unit/java/android/net/nsd/NsdManagerTest.java b/tests/unit/java/android/net/nsd/NsdManagerTest.java
index b0a9b8a553223c77ebedac2c0729fce7ffd7b3f6..370179c2b11157e672cdfecc23b12f67c0a50fac 100644
--- a/tests/unit/java/android/net/nsd/NsdManagerTest.java
+++ b/tests/unit/java/android/net/nsd/NsdManagerTest.java
@@ -20,12 +20,12 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
 
 import android.content.Context;
 import android.os.Handler;
@@ -66,7 +66,7 @@ public class NsdManagerTest {
         MockitoAnnotations.initMocks(this);
 
         mServiceHandler = spy(MockServiceHandler.create(mContext));
-        when(mService.getMessenger()).thenReturn(new Messenger(mServiceHandler));
+        doReturn(new Messenger(mServiceHandler)).when(mService).getMessenger();
 
         mManager = makeManager();
     }
diff --git a/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt b/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt
index 25aa6266577e8241112c3a520a56944aeceecf45..78c8fa408cab41bbc55431afd6e00b624ea8e817 100644
--- a/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt
+++ b/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt
@@ -45,6 +45,7 @@ import org.mockito.ArgumentMatchers.anyInt
 import org.mockito.ArgumentMatchers.argThat
 import org.mockito.ArgumentMatchers.eq
 import org.mockito.Mockito.any
+import org.mockito.Mockito.doCallRealMethod
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.times
@@ -74,6 +75,10 @@ class MultinetworkPolicyTrackerTest {
         doReturn(Context.TELEPHONY_SERVICE).`when`(it)
                 .getSystemServiceName(TelephonyManager::class.java)
         doReturn(telephonyManager).`when`(it).getSystemService(Context.TELEPHONY_SERVICE)
+        if (it.getSystemService(TelephonyManager::class.java) == null) {
+            // Test is using mockito extended
+            doCallRealMethod().`when`(it).getSystemService(TelephonyManager::class.java)
+        }
         doReturn(subscriptionManager).`when`(it)
                 .getSystemService(Context.TELEPHONY_SUBSCRIPTION_SERVICE)
         doReturn(resolver).`when`(it).contentResolver
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 2aa95ee22ebe59dbb21ddad8010a8aa797f9c5e2..acea2863bab6b513402853e5bebd0e8513ea5bce 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -393,7 +393,7 @@ import kotlin.reflect.KClass;
 public class ConnectivityServiceTest {
     private static final String TAG = "ConnectivityServiceTest";
 
-    private static final int TIMEOUT_MS = 500;
+    private static final int TIMEOUT_MS = 2_000;
     // Broadcasts can take a long time to be delivered. The test will not wait for that long unless
     // there is a failure, so use a long timeout.
     private static final int BROADCAST_TIMEOUT_MS = 30_000;
diff --git a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java
index 38f6d7f3172d90dea1163f62967c69f1dc0d8c0f..4c80f6a21db544e99242071dddd67ce2fa87597e 100644
--- a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java
@@ -34,6 +34,7 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -114,8 +115,12 @@ public class MultipathPolicyTrackerTest {
     private boolean mRecurrenceRuleClockMocked;
 
     private <T> void mockService(String serviceName, Class<T> serviceClass, T service) {
-        when(mContext.getSystemServiceName(serviceClass)).thenReturn(serviceName);
-        when(mContext.getSystemService(serviceName)).thenReturn(service);
+        doReturn(serviceName).when(mContext).getSystemServiceName(serviceClass);
+        doReturn(service).when(mContext).getSystemService(serviceName);
+        if (mContext.getSystemService(serviceClass) == null) {
+            // Test is using mockito-extended
+            doCallRealMethod().when(mContext).getSystemService(serviceClass);
+        }
     }
 
     @Before
diff --git a/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java b/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java
index db6349533817f1b284596d87bfaa8c15010afe99..6bf6cc5493ba1eb5a508ba208b471afab9e1e7e0 100644
--- a/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java
@@ -51,6 +51,7 @@ import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
@@ -139,6 +140,10 @@ public class PermissionMonitorTest {
                 .thenReturn(Context.SYSTEM_CONFIG_SERVICE);
         when(mContext.getSystemService(Context.SYSTEM_CONFIG_SERVICE))
                 .thenReturn(mSystemConfigManager);
+        if (mContext.getSystemService(SystemConfigManager.class) == null) {
+            // Test is using mockito-extended
+            doCallRealMethod().when(mContext).getSystemService(SystemConfigManager.class);
+        }
         when(mSystemConfigManager.getSystemPermissionUids(anyString())).thenReturn(new int[0]);
         final Context asUserCtx = mock(Context.class, AdditionalAnswers.delegatesTo(mContext));
         doReturn(UserHandle.ALL).when(asUserCtx).getUser();
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index b725b826b14f0f342ab53ead0b0bccfa68978f44..6ff47aea34229010d7766db504f00558fb2475a5 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -39,6 +39,7 @@ import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
@@ -219,19 +220,11 @@ public class VpnTest {
 
         when(mContext.getPackageName()).thenReturn(TEST_VPN_PKG);
         when(mContext.getOpPackageName()).thenReturn(TEST_VPN_PKG);
-        when(mContext.getSystemServiceName(UserManager.class))
-                .thenReturn(Context.USER_SERVICE);
-        when(mContext.getSystemService(eq(Context.USER_SERVICE))).thenReturn(mUserManager);
-        when(mContext.getSystemService(eq(Context.APP_OPS_SERVICE))).thenReturn(mAppOps);
-        when(mContext.getSystemServiceName(NotificationManager.class))
-                .thenReturn(Context.NOTIFICATION_SERVICE);
-        when(mContext.getSystemService(eq(Context.NOTIFICATION_SERVICE)))
-                .thenReturn(mNotificationManager);
-        when(mContext.getSystemService(eq(Context.CONNECTIVITY_SERVICE)))
-                .thenReturn(mConnectivityManager);
-        when(mContext.getSystemServiceName(eq(ConnectivityManager.class)))
-                .thenReturn(Context.CONNECTIVITY_SERVICE);
-        when(mContext.getSystemService(eq(Context.IPSEC_SERVICE))).thenReturn(mIpSecManager);
+        mockService(UserManager.class, Context.USER_SERVICE, mUserManager);
+        mockService(AppOpsManager.class, Context.APP_OPS_SERVICE, mAppOps);
+        mockService(NotificationManager.class, Context.NOTIFICATION_SERVICE, mNotificationManager);
+        mockService(ConnectivityManager.class, Context.CONNECTIVITY_SERVICE, mConnectivityManager);
+        mockService(IpSecManager.class, Context.IPSEC_SERVICE, mIpSecManager);
         when(mContext.getString(R.string.config_customVpnAlwaysOnDisconnectedDialogComponent))
                 .thenReturn(Resources.getSystem().getString(
                         R.string.config_customVpnAlwaysOnDisconnectedDialogComponent));
@@ -259,6 +252,16 @@ public class VpnTest {
                 .thenReturn(tunnelResp);
     }
 
+    private <T> void mockService(Class<T> clazz, String name, T service) {
+        doReturn(service).when(mContext).getSystemService(name);
+        doReturn(name).when(mContext).getSystemServiceName(clazz);
+        if (mContext.getSystemService(clazz).getClass().equals(Object.class)) {
+            // Test is using mockito-extended (mContext uses Answers.RETURNS_DEEP_STUBS and returned
+            // a mock object on a final method)
+            doCallRealMethod().when(mContext).getSystemService(clazz);
+        }
+    }
+
     private Set<Range<Integer>> rangeSet(Range<Integer> ... ranges) {
         final Set<Range<Integer>> range = new ArraySet<>();
         for (Range<Integer> r : ranges) range.add(r);