Prefer wlan over wwan network

Prior to the change wwan network, the always available default network,
will always be used for sending MMS even if a wlan(better) network
available, causing user complaints when wwan coverage is poor.

In this change, MMS will be first attempted to be sent via wwan, while waiting for wlan
to be available. Once wlan available, we disconnect all wwan http
connections in order to retry on the newly available wlan network.

The enhancement is guarded by device config flag
"mms_enhancement_enabled" with default true.

Fix: 231972603
Test: Local manually reproduce in poor wwan connection area and observe
the improvement from end user POV.
Test: tester at b/294076938
Change-Id: I695add2d454581b3db70a77adbe73175b664f972
diff --git a/PREUPLOAD.cfg b/PREUPLOAD.cfg
new file mode 100644
index 0000000..e5f1877
--- /dev/null
+++ b/PREUPLOAD.cfg
@@ -0,0 +1,2 @@
+[Hook Scripts]
+checkstyle_hook = ${REPO_ROOT}/prebuilts/checkstyle/checkstyle.py --sha ${PREUPLOAD_COMMIT}
\ No newline at end of file
diff --git a/src/com/android/mms/service/MmsHttpClient.java b/src/com/android/mms/service/MmsHttpClient.java
index b54b1aa..7978a71 100644
--- a/src/com/android/mms/service/MmsHttpClient.java
+++ b/src/com/android/mms/service/MmsHttpClient.java
@@ -29,7 +29,9 @@
 import android.util.Base64;
 import android.util.Log;
 
+import com.android.internal.annotations.VisibleForTesting;
 import com.android.mms.service.exception.MmsHttpException;
+import com.android.mms.service.exception.VoluntaryDisconnectMmsHttpException;
 
 import java.io.BufferedInputStream;
 import java.io.BufferedOutputStream;
@@ -49,11 +51,12 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
-import com.android.internal.annotations.VisibleForTesting;
-
 /**
  * MMS HTTP client for sending and downloading MMS messages
  */
@@ -87,6 +90,11 @@
     private final Network mNetwork;
     private final ConnectivityManager mConnectivityManager;
 
+    /** Store all currently open connections, for potential voluntarily early disconnect. */
+    private final Set<HttpURLConnection> mAllUrlConnections = ConcurrentHashMap.newKeySet();
+    /** Flag indicating whether a disconnection is voluntary. */
+    private final AtomicBoolean mVoluntarilyDisconnectingConnections = new AtomicBoolean(false);
+
     /**
      * Constructor
      *
@@ -136,6 +144,7 @@
             maybeWaitForIpv4(requestId, url);
             // Now get the connection
             connection = (HttpURLConnection) mNetwork.openConnection(url, proxy);
+            if (connection != null) mAllUrlConnections.add(connection);
             connection.setDoInput(true);
             connection.setConnectTimeout(
                     mmsConfig.getInt(SmsManager.MMS_CONFIG_HTTP_SOCKET_TIMEOUT));
@@ -238,15 +247,46 @@
             LogUtil.e(requestId, "HTTP: invalid URL protocol " + redactedUrl, e);
             throw new MmsHttpException(0/*statusCode*/, "Invalid URL protocol " + redactedUrl, e);
         } catch (IOException e) {
-            LogUtil.e(requestId, "HTTP: IO failure", e);
-            throw new MmsHttpException(0/*statusCode*/, e);
+            if (mVoluntarilyDisconnectingConnections.get()) {
+                // If in the process of voluntarily disconnecting all connections, the exception
+                // is casted as VoluntaryDisconnectMmsHttpException to indicate this attempt is
+                // cancelled rather than failure.
+                LogUtil.d(requestId,
+                        "HTTP voluntarily disconnected due to WLAN network available");
+                throw new VoluntaryDisconnectMmsHttpException(0/*statusCode*/,
+                        "Expected disconnection due to WLAN network available");
+            } else {
+                LogUtil.e(requestId, "HTTP: IO failure ", e);
+                throw new MmsHttpException(0/*statusCode*/, e);
+            }
         } finally {
             if (connection != null) {
                 connection.disconnect();
+                mAllUrlConnections.remove(connection);
+                // If all connections are done disconnected, flag voluntary disconnection done if
+                // applicable.
+                if (mAllUrlConnections.isEmpty() && mVoluntarilyDisconnectingConnections
+                        .compareAndSet(/*expectedValue*/true, /*newValue*/false)) {
+                    LogUtil.d("All voluntarily disconnected connections are removed.");
+                }
             }
         }
     }
 
+    /**
+     * Voluntarily disconnect all Http URL connections. This will trigger
+     * {@link VoluntaryDisconnectMmsHttpException} to be thrown, to indicate voluntary disconnection
+     */
+    public void disconnectAllUrlConnections() {
+        LogUtil.d("Disconnecting all Url connections, size = " + mAllUrlConnections.size());
+        if (mAllUrlConnections.isEmpty()) return;
+        mVoluntarilyDisconnectingConnections.set(true);
+        for (HttpURLConnection connection : mAllUrlConnections) {
+            // TODO: An improvement is to check the writing/reading progress before disconnect.
+            connection.disconnect();
+        }
+    }
+
     private void maybeWaitForIpv4(final String requestId, final URL url) {
         // If it's a literal IPv4 address and we're on an IPv6-only network,
         // wait until IPv4 is available.
diff --git a/src/com/android/mms/service/MmsNetworkManager.java b/src/com/android/mms/service/MmsNetworkManager.java
index f21e510..1f872e0 100644
--- a/src/com/android/mms/service/MmsNetworkManager.java
+++ b/src/com/android/mms/service/MmsNetworkManager.java
@@ -16,6 +16,7 @@
 
 package com.android.mms.service;
 
+import android.annotation.NonNull;
 import android.content.BroadcastReceiver;
 import android.content.Context;
 import android.content.Intent;
@@ -26,7 +27,6 @@
 import android.net.NetworkInfo;
 import android.net.NetworkRequest;
 import android.net.TelephonyNetworkSpecifier;
-import android.os.Binder;
 import android.os.Handler;
 import android.os.Looper;
 import android.os.Message;
@@ -35,7 +35,6 @@
 import android.telephony.CarrierConfigManager;
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
-import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.telephony.PhoneConstants;
@@ -45,8 +44,10 @@
  * Manages the MMS network connectivity
  */
 public class MmsNetworkManager {
+    /** Device Config Keys */
     private static final String MMS_SERVICE_NETWORK_REQUEST_TIMEOUT_MILLIS =
             "mms_service_network_request_timeout_millis";
+    private static final String MMS_ENHANCEMENT_ENABLED = "mms_enhancement_enabled";
 
     // Default timeout used to call ConnectivityManager.requestNetwork if the
     // MMS_SERVICE_NETWORK_REQUEST_TIMEOUT_MILLIS flag is not set.
@@ -59,6 +60,8 @@
 
     /* Event created when receiving ACTION_CARRIER_CONFIG_CHANGED */
     private static final int EVENT_CARRIER_CONFIG_CHANGED = 1;
+    /** Event when a WLAN network newly available despite of the existing available one. */
+    private static final int EVENT_IWLAN_NETWORK_NEWLY_AVAILABLE = 2;
 
     private final Context mContext;
 
@@ -66,6 +69,8 @@
     // We need this when we unbind from it. This is also used to indicate if the
     // MMS network is available.
     private Network mNetwork;
+    /** Whether an Iwlan MMS network is available to use. */
+    private boolean mIsLastAvailableNetworkIwlan;
     // The current count of MMS requests that require the MMS network
     // If mMmsRequestCount is 0, we should release the MMS network.
     private int mMmsRequestCount;
@@ -116,6 +121,9 @@
                     // Reload mNetworkReleaseTimeoutMillis from CarrierConfigManager.
                     handleCarrierConfigChanged();
                     break;
+                case EVENT_IWLAN_NETWORK_NEWLY_AVAILABLE:
+                    onIwlanNetworkNewlyAvailable();
+                    break;
                 default:
                     LogUtil.e("MmsNetworkManager: ignoring message of unexpected type " + msg.what);
             }
@@ -187,6 +195,17 @@
         }
     };
 
+    /**
+     * Called when a WLAN network newly available. This new WLAN network should replace the
+     * existing network and retry sending traffic on this network.
+     */
+    private void onIwlanNetworkNewlyAvailable() {
+        if (mMmsHttpClient == null || mNetwork == null) return;
+        LogUtil.d("onIwlanNetworkNewlyAvailable net " + mNetwork.getNetId());
+        mMmsHttpClient.disconnectAllUrlConnections();
+        populateHttpClientWithCurrentNetwork();
+    }
+
     private void handleCarrierConfigChanged() {
         final CarrierConfigManager configManager =
                 (CarrierConfigManager)
@@ -230,8 +249,13 @@
             // onAvailable will always immediately be followed by a onCapabilitiesChanged. Check
             // network status here is enough.
             super.onCapabilitiesChanged(network, nc);
+            final NetworkInfo networkInfo = getConnectivityManager().getNetworkInfo(network);
+            // wlan network is preferred over wwan network, because its existence meaning it's
+            // recommended by QualifiedNetworksService.
+            final boolean isWlan = networkInfo != null
+                    && networkInfo.getSubtype() == TelephonyManager.NETWORK_TYPE_IWLAN;
             LogUtil.w("NetworkCallbackListener.onCapabilitiesChanged: network="
-                    + network + ", nc=" + nc);
+                    + network + ", isWlan=" + isWlan + ", nc=" + nc);
             synchronized (MmsNetworkManager.this) {
                 final boolean isAvailable =
                         nc.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED);
@@ -244,10 +268,18 @@
                     return;
                 }
 
-                // New available network
-                if (mNetwork == null && isAvailable) {
-                    mNetwork = network;
-                    MmsNetworkManager.this.notifyAll();
+                // Use new available network
+                if (isAvailable) {
+                    if (mNetwork == null) {
+                        mNetwork = network;
+                        MmsNetworkManager.this.notifyAll();
+                    } else if (mDeps.isMmsEnhancementEnabled()
+                            // Iwlan network newly available, try send MMS over the new network.
+                            && !mIsLastAvailableNetworkIwlan && isWlan) {
+                        mNetwork = network;
+                        mEventHandler.sendEmptyMessage(EVENT_IWLAN_NETWORK_NEWLY_AVAILABLE);
+                    }
+                    mIsLastAvailableNetworkIwlan = isWlan;
                 }
             }
         }
@@ -271,6 +303,11 @@
                     DEFAULT_MMS_SERVICE_NETWORK_REQUEST_TIMEOUT_MILLIS);
         }
 
+        public boolean isMmsEnhancementEnabled() {
+            return DeviceConfig.getBoolean(
+                    DeviceConfig.NAMESPACE_TELEPHONY, MMS_ENHANCEMENT_ENABLED, true);
+        }
+
         public int getAdditionalNetworkAcquireTimeoutMillis() {
             return ADDITIONAL_NETWORK_ACQUIRE_TIMEOUT_MILLIS;
         }
@@ -395,17 +432,22 @@
      * Release the MMS network when nobody is holding on to it.
      *
      * @param requestId          request ID for logging.
+     * @param canRelease         whether the request can be released. An early release of a request
+     *                           can result in unexpected network torn down, as that network is used
+     *                           for immediate retry.
      * @param shouldDelayRelease whether the release should be delayed for a carrier-configured
      *                           timeout (default 5 seconds), the regular use case is to delay this
      *                           for DownloadRequests to use the network for sending an
      *                           acknowledgement on the same network.
      */
-    public void releaseNetwork(final String requestId, final boolean shouldDelayRelease) {
+    public void releaseNetwork(final String requestId, final boolean canRelease,
+            final boolean shouldDelayRelease) {
         synchronized (this) {
             if (mMmsRequestCount > 0) {
                 mMmsRequestCount -= 1;
-                LogUtil.d(requestId, "MmsNetworkManager: release, count=" + mMmsRequestCount);
-                if (mMmsRequestCount < 1) {
+                LogUtil.d(requestId, "MmsNetworkManager: release, count=" + mMmsRequestCount
+                        + " canRelease=" + canRelease);
+                if (mMmsRequestCount < 1 && canRelease) {
                     if (shouldDelayRelease) {
                         // remove previously posted task and post a delayed task on the release
                         // handler to release the network
@@ -464,7 +506,7 @@
         mMmsHttpClient = null;
     }
 
-    private ConnectivityManager getConnectivityManager() {
+    private @NonNull ConnectivityManager getConnectivityManager() {
         if (mConnectivityManager == null) {
             mConnectivityManager = (ConnectivityManager) mContext.getSystemService(
                     Context.CONNECTIVITY_SERVICE);
@@ -480,15 +522,19 @@
     public MmsHttpClient getOrCreateHttpClient() {
         synchronized (this) {
             if (mMmsHttpClient == null) {
-                if (mNetwork != null) {
-                    // Create new MmsHttpClient for the current Network
-                    mMmsHttpClient = new MmsHttpClient(mContext, mNetwork, mConnectivityManager);
-                }
+                populateHttpClientWithCurrentNetwork();
             }
             return mMmsHttpClient;
         }
     }
 
+    // Create new MmsHttpClient for the current Network
+    private void populateHttpClientWithCurrentNetwork() {
+        if (mNetwork != null) {
+            mMmsHttpClient = new MmsHttpClient(mContext, mNetwork, mConnectivityManager);
+        }
+    }
+
     /**
      * Get the APN name for the active network
      *
diff --git a/src/com/android/mms/service/MmsRequest.java b/src/com/android/mms/service/MmsRequest.java
index 5cf2ba6..8e51896 100644
--- a/src/com/android/mms/service/MmsRequest.java
+++ b/src/com/android/mms/service/MmsRequest.java
@@ -16,6 +16,7 @@
 
 package com.android.mms.service;
 
+import android.annotation.NonNull;
 import android.app.Activity;
 import android.app.PendingIntent;
 import android.content.Context;
@@ -39,9 +40,12 @@
 import com.android.mms.service.exception.ApnException;
 import com.android.mms.service.exception.MmsHttpException;
 import com.android.mms.service.exception.MmsNetworkException;
+import com.android.mms.service.exception.VoluntaryDisconnectMmsHttpException;
 import com.android.mms.service.metrics.MmsStats;
 
 import java.util.UUID;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 
 /**
  * Base class for MMS requests. This has the common logic of sending/downloading MMS.
@@ -164,15 +168,14 @@
         byte[] response = null;
         int retryId = 0;
         currentState = MmsRequestState.PrepareForHttpRequest;
-        // TODO: add mms data channel check back to fast fail if no way to send mms,
-        // when telephony provides such API.
+        int attemptedTimes = 0;
         if (!prepareForHttpRequest()) { // Prepare request, like reading pdu data from user
             LogUtil.e(requestId, "Failed to prepare for request");
             result = SmsManager.MMS_ERROR_IO_ERROR;
         } else { // Execute
             long retryDelaySecs = 2;
             // Try multiple times of MMS HTTP request, depending on the error.
-            for (retryId = 0; retryId < RETRY_TIMES; retryId++) {
+            while (retryId < RETRY_TIMES) {
                 httpStatusCode = 0; // Clear for retry.
                 MonitorTelephonyCallback connectionStateCallback = new MonitorTelephonyCallback();
                 try {
@@ -181,32 +184,26 @@
                     networkManager.acquireNetwork(requestId);
                     final String apnName = networkManager.getApnName();
                     LogUtil.d(requestId, "APN name is " + apnName);
+                    ApnSettings apn = null;
+                    currentState = MmsRequestState.LoadingApn;
                     try {
-                        ApnSettings apn = null;
-                        currentState = MmsRequestState.LoadingApn;
-                        try {
-                            apn = ApnSettings.load(context, apnName, mSubId, requestId);
-                        } catch (ApnException e) {
-                            // If no APN could be found, fall back to trying without the APN name
-                            if (apnName == null) {
-                                // If the APN name was already null then don't need to retry
-                                throw (e);
-                            }
-                            LogUtil.i(requestId, "No match with APN name: "
-                                    + apnName + ", try with no name");
-                            apn = ApnSettings.load(context, null, mSubId, requestId);
+                        apn = ApnSettings.load(context, apnName, mSubId, requestId);
+                    } catch (ApnException e) {
+                        // If no APN could be found, fall back to trying without the APN name
+                        if (apnName == null) {
+                            // If the APN name was already null then don't need to retry
+                            throw (e);
                         }
-                        LogUtil.i(requestId, "Using " + apn.toString());
-                        currentState = MmsRequestState.DoingHttp;
-                        response = doHttp(context, networkManager, apn);
-                        result = Activity.RESULT_OK;
-                        // Success
-                        break;
-                    } finally {
-                        // Release the MMS network immediately except successful DownloadRequest.
-                        networkManager.releaseNetwork(requestId,
-                                this instanceof DownloadRequest && result == Activity.RESULT_OK);
+                        LogUtil.i(requestId, "No match with APN name: "
+                                + apnName + ", try with no name");
+                        apn = ApnSettings.load(context, null, mSubId, requestId);
                     }
+                    LogUtil.i(requestId, "Using " + apn.toString());
+                    currentState = MmsRequestState.DoingHttp;
+                    response = doHttp(context, networkManager, apn);
+                    result = Activity.RESULT_OK;
+                    // Success
+                    break;
                 } catch (ApnException e) {
                     LogUtil.e(requestId, "APN failure", e);
                     result = SmsManager.MMS_ERROR_INVALID_APN;
@@ -216,8 +213,12 @@
                     result = SmsManager.MMS_ERROR_UNABLE_CONNECT_MMS;
                     break;
                 } catch (MmsHttpException e) {
-                    LogUtil.e(requestId, "HTTP or network I/O failure", e);
-                    result = SmsManager.MMS_ERROR_HTTP_FAILURE;
+                    if (e instanceof VoluntaryDisconnectMmsHttpException) {
+                        result = Activity.RESULT_CANCELED;
+                    } else {
+                        LogUtil.e(requestId, "HTTP or network I/O failure", e);
+                        result = SmsManager.MMS_ERROR_HTTP_FAILURE;
+                    }
                     httpStatusCode = e.getStatusCode();
                     // Retry
                 } catch (Exception e) {
@@ -225,12 +226,43 @@
                     result = SmsManager.MMS_ERROR_UNSPECIFIED;
                     break;
                 } finally {
+                    // Don't release the MMS network if the last attempt was voluntarily
+                    // cancelled (due to better network available), because releasing the request
+                    // could result that network being torn down as it's thought to be useless.
+                    boolean canRelease = false;
+                    if (result != Activity.RESULT_CANCELED) {
+                        retryId++;
+                        canRelease = true;
+                    }
+                    // Otherwise, delay the release for successful download request.
+                    networkManager.releaseNetwork(requestId, canRelease,
+                            this instanceof DownloadRequest && result == Activity.RESULT_OK);
+
                     stopListeningToDataConnectionState(connectionStateCallback);
                 }
-                try {
-                    Thread.sleep(retryDelaySecs * 1000, 0/*nano*/);
-                } catch (InterruptedException e) {}
-                retryDelaySecs <<= 1;
+
+                // THEORETICALLY WOULDN'T OCCUR - PUTTING HERE AS A SAFETY NET.
+                // TODO: REMOVE WITH FLAG mms_enhancement_enabled after soaking enough time, V-QPR.
+                // Only possible if network kept disconnecting due to Activity.RESULT_CANCELED,
+                // causing retryId doesn't increase and thus stuck in the infinite loop.
+                // However, it's theoretically impossible because RESULT_CANCELED is only triggered
+                // when a WLAN network becomes newly available in addition to an existing network.
+                // Therefore, the WLAN network's own death cannot be triggered by RESULT_CANCELED,
+                // and thus must result in retryId++.
+                if (++attemptedTimes > RETRY_TIMES * 2) {
+                    LogUtil.e(requestId, "Retry is performed too many times");
+                    reportAnomaly("MMS retried too many times",
+                            UUID.fromString("038c9155-5daa-4515-86ae-aafdd33c1435"));
+                    break;
+                }
+
+                if (result != Activity.RESULT_CANCELED) {
+                    try { // Cool down retry if the previous attempt wasn't voluntarily cancelled.
+                        new CountDownLatch(1).await(retryDelaySecs, TimeUnit.SECONDS);
+                    } catch (InterruptedException e) { }
+                    // Double the cool down time if the next try fails again.
+                    retryDelaySecs <<= 1;
+                }
             }
         }
         processResult(context, result, response, httpStatusCode, /* handledByCarrierApp= */ false,
@@ -328,19 +360,25 @@
                 String message = "MMS failed";
                 LogUtil.i(this.toString(),
                         message + " with error: " + result + " httpStatus:" + httpStatusCode);
-                TelephonyManager telephonyManager =
-                        mContext.getSystemService(TelephonyManager.class)
-                                .createForSubscriptionId(mSubId);
-                AnomalyReporter.reportAnomaly(
-                        generateUUID(result, httpStatusCode),
-                        message,
-                        telephonyManager.getSimCarrierId());
+                reportAnomaly(message, generateUUID(result, httpStatusCode));
                 break;
             default:
                 break;
         }
     }
 
+    private void reportAnomaly(@NonNull String anomalyMsg, @NonNull UUID uuid) {
+        TelephonyManager telephonyManager =
+                mContext.getSystemService(TelephonyManager.class)
+                        .createForSubscriptionId(mSubId);
+        if (telephonyManager != null) {
+            AnomalyReporter.reportAnomaly(
+                    uuid,
+                    anomalyMsg,
+                    telephonyManager.getSimCarrierId());
+        }
+    }
+
     private UUID generateUUID(int result, int httpStatusCode) {
         long lresult = result;
         long lhttpStatusCode = httpStatusCode;
diff --git a/src/com/android/mms/service/exception/VoluntaryDisconnectMmsHttpException.java b/src/com/android/mms/service/exception/VoluntaryDisconnectMmsHttpException.java
new file mode 100644
index 0000000..2b6840a
--- /dev/null
+++ b/src/com/android/mms/service/exception/VoluntaryDisconnectMmsHttpException.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.mms.service.exception;
+
+/**
+ * Thrown when voluntarily disconnect an MMS http connection to trigger immediate retry. This
+ * exception indicates the connection is voluntarily cancelled, instead of a failure.
+ */
+public class VoluntaryDisconnectMmsHttpException extends MmsHttpException{
+    public VoluntaryDisconnectMmsHttpException(int statusCode, String message) {
+        super(statusCode, message);
+    }
+}
diff --git a/tests/robotests/src/com/android/mms/service/MmsNetworkManagerTest.java b/tests/robotests/src/com/android/mms/service/MmsNetworkManagerTest.java
index dff2bab..70b1a0a 100644
--- a/tests/robotests/src/com/android/mms/service/MmsNetworkManagerTest.java
+++ b/tests/robotests/src/com/android/mms/service/MmsNetworkManagerTest.java
@@ -22,13 +22,16 @@
 import static junit.framework.Assert.assertFalse;
 import static junit.framework.Assert.fail;
 
+import static org.junit.Assert.assertNotSame;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.verify;
+import static org.robolectric.RuntimeEnvironment.getMasterScheduler;
 
 import android.content.Context;
 import android.net.ConnectivityManager;
@@ -38,6 +41,7 @@
 import android.net.NetworkInfo;
 import android.os.PersistableBundle;
 import android.telephony.CarrierConfigManager;
+import android.telephony.TelephonyManager;
 
 import org.junit.Before;
 import org.junit.Test;
@@ -47,6 +51,7 @@
 import org.mockito.MockitoAnnotations;
 import org.robolectric.RobolectricTestRunner;
 
+import java.lang.reflect.Field;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -150,6 +155,29 @@
     }
 
     @Test
+    public void testAvailableNetwork_wwanNetworkReplacedByWlanNetwork() throws Exception {
+        MmsHttpClient mockMmsHttpClient = mock(MmsHttpClient.class);
+        doReturn(true).when(mDeps).isMmsEnhancementEnabled();
+
+        // WWAN network is always available, whereas WLAN network needs extra time to be set up.
+        doReturn(TelephonyManager.NETWORK_TYPE_LTE).when(mNetworkInfo).getSubtype();
+        doReturn(TelephonyManager.NETWORK_TYPE_IWLAN).when(mNetworkInfo2).getSubtype();
+
+        final NetworkCallback callback = acquireAvailableNetworkAndGetCallback(
+                mTestNetwork /* expectNetwork */, MMS_APN /* expectApn */);
+        replaceInstance(MmsNetworkManager.class, "mMmsHttpClient", mMnm,
+                mockMmsHttpClient);
+
+        // The WLAN network become available.
+        callback.onCapabilitiesChanged(mTestNetwork2, USABLE_NC);
+        getMasterScheduler().advanceToLastPostedRunnable();
+
+        // Verify current connections disconnect, then the client is replaced with a new network.
+        verify(mockMmsHttpClient).disconnectAllUrlConnections();
+        assertNotSame(mMnm.getOrCreateHttpClient(), mockMmsHttpClient);
+    }
+
+    @Test
     public void testAvailableNetwork_networkBecomeSuspend() throws Exception {
         final NetworkCallback callback = acquireAvailableNetworkAndGetCallback(
                 mTestNetwork /* expectNetwork */, MMS_APN /* expectApn */);
@@ -256,4 +284,12 @@
 
         return future;
     }
+
+    /** Helper to replace instance field with reflection. */
+    private void replaceInstance(final Class c, final String instanceName,
+            final Object obj, final Object newValue) throws Exception {
+        Field field = c.getDeclaredField(instanceName);
+        field.setAccessible(true);
+        field.set(obj, newValue);
+    }
 }
diff --git a/tests/unittests/src/com/android/mms/service/MmsHttpClientTest.java b/tests/unittests/src/com/android/mms/service/MmsHttpClientTest.java
index dd126e8..e144ef7 100644
--- a/tests/unittests/src/com/android/mms/service/MmsHttpClientTest.java
+++ b/tests/unittests/src/com/android/mms/service/MmsHttpClientTest.java
@@ -18,9 +18,14 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
+import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -29,22 +34,25 @@
 import android.net.ConnectivityManager;
 import android.net.Network;
 import android.os.Bundle;
-import android.telephony.ServiceState;
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 
 import androidx.test.core.app.ApplicationProvider;
 
+import com.android.mms.service.exception.VoluntaryDisconnectMmsHttpException;
+
 import org.junit.After;
 import org.junit.Before;
-import org.mockito.MockitoAnnotations;
-import org.mockito.Spy;
 import org.junit.Test;
+import org.mockito.MockitoAnnotations;
 
-import static org.mockito.ArgumentMatchers.anyInt;
-import static org.mockito.Mockito.reset;
-
-import android.util.Log;
+import java.io.IOException;
+import java.net.HttpURLConnection;
+import java.net.SocketException;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 
 public class MmsHttpClientTest {
     // Mocked classes
@@ -145,4 +153,39 @@
         assertThat(phoneNo).contains(subscriberPhoneNumber);
         verify(mSubscriptionManager).getPhoneNumber(subId);
     }
+
+    @Test
+    public void testDisconnectAllUrlConnections() throws IOException {
+        Network mockNetwork = mock(Network.class);
+        HttpURLConnection mockConnection = mock(HttpURLConnection.class);
+        doReturn(mockConnection).when(mockNetwork).openConnection(any(), any());
+        doReturn(mockNetwork).when(mockNetwork).getPrivateDnsBypassingCopy();
+        ConnectivityManager mockCm = mock(ConnectivityManager.class);
+        Bundle config = new Bundle();
+
+        // The external thread that voluntarily silently close the socket.
+        CountDownLatch latch = new CountDownLatch(1);
+        final ExecutorService externalThread = Executors.newSingleThreadExecutor();
+        doAnswer(invok -> {
+            latch.countDown();
+            return null;
+        }).when(mockConnection).disconnect();
+
+        MmsHttpClient clientUT = new MmsHttpClient(mContext, mockNetwork, mockCm);
+        doAnswer(invok -> {
+            externalThread.execute(clientUT::disconnectAllUrlConnections);
+            // connection.disconnect is silent, but it will trigger SocketException thrown from the
+            // connect thread.
+            if (latch.await(1, TimeUnit.SECONDS)) {
+                throw new SocketException("Socket Closed");
+            }
+            return null;
+        }).when(mockConnection).getResponseCode();
+
+        // Verify SocketException is transformed into VoluntaryDisconnectMmsHttpException
+        assertThrows(VoluntaryDisconnectMmsHttpException.class, () -> {
+            clientUT.execute("http://test", new byte[0], "GET", false,
+                                "", 0, config, 1, "requestId");
+        });
+    }
 }