MediaServerCrashTest: add testDrmManagerClientReset.

Bug: 25070434
Change-Id: Iab68019ed12154175a899680d2d60eedfd91f2f1
diff --git a/tests/tests/security/AndroidManifest.xml b/tests/tests/security/AndroidManifest.xml
index d8d8dfa..3039f81 100644
--- a/tests/tests/security/AndroidManifest.xml
+++ b/tests/tests/security/AndroidManifest.xml
@@ -23,6 +23,7 @@
     <uses-permission android:name="android.permission.CHANGE_NETWORK_STATE" />
     <uses-permission android:name="android.permission.INTERNET" />
     <uses-permission android:name="android.permission.MODIFY_AUDIO_SETTINGS" />
+    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
 
     <application>
         <uses-library android:name="android.test.runner" />
diff --git a/tests/tests/security/res/raw/drm_uaf.dm b/tests/tests/security/res/raw/drm_uaf.dm
new file mode 100644
index 0000000..50f42bc
--- /dev/null
+++ b/tests/tests/security/res/raw/drm_uaf.dm
Binary files differ
diff --git a/tests/tests/security/src/android/security/cts/MediaServerCrashTest.java b/tests/tests/security/src/android/security/cts/MediaServerCrashTest.java
index 4dc6783..9dff6dd 100644
--- a/tests/tests/security/src/android/security/cts/MediaServerCrashTest.java
+++ b/tests/tests/security/src/android/security/cts/MediaServerCrashTest.java
@@ -17,70 +17,259 @@
 package android.security.cts;
 
 import android.content.res.AssetFileDescriptor;
+import android.drm.DrmConvertedStatus;
+import android.drm.DrmManagerClient;
 import android.media.MediaPlayer;
 import android.os.ConditionVariable;
+import android.os.Environment;
+import android.os.ParcelFileDescriptor;
 import android.test.AndroidTestCase;
 import android.util.Log;
 
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.RandomAccessFile;
+
 import com.android.cts.security.R;
 
 public class MediaServerCrashTest extends AndroidTestCase {
     private static final String TAG = "MediaServerCrashTest";
 
-    public void testInvalidMidiNullPointerAccess() throws Exception {
-        testIfMediaServerDied(R.raw.midi_crash);
-    }
+    private static final String MIMETYPE_DRM_MESSAGE = "application/vnd.oma.drm.message";
 
-    private void testIfMediaServerDied(int res) throws Exception {
-        final MediaPlayer mediaPlayer = new MediaPlayer();
-        final ConditionVariable onPrepareCalled = new ConditionVariable();
-        final ConditionVariable onCompletionCalled = new ConditionVariable();
+    private String mFlFilePath;
 
-        onPrepareCalled.close();
-        onCompletionCalled.close();
-        mediaPlayer.setOnErrorListener(new MediaPlayer.OnErrorListener() {
+    private final MediaPlayer mMediaPlayer = new MediaPlayer();
+    private final ConditionVariable mOnPrepareCalled = new ConditionVariable();
+    private final ConditionVariable mOnCompletionCalled = new ConditionVariable();
+
+    @Override
+    protected void setUp() throws Exception {
+        super.setUp();
+        mFlFilePath = new File(Environment.getExternalStorageDirectory(),
+                "temp.fl").getAbsolutePath();
+
+        mOnPrepareCalled.close();
+        mOnCompletionCalled.close();
+        mMediaPlayer.setOnErrorListener(new MediaPlayer.OnErrorListener() {
             @Override
             public boolean onError(MediaPlayer mp, int what, int extra) {
-                assertTrue(mp == mediaPlayer);
+                assertTrue(mp == mMediaPlayer);
                 assertTrue("mediaserver process died", what != MediaPlayer.MEDIA_ERROR_SERVER_DIED);
                 Log.w(TAG, "onError " + what);
                 return false;
             }
         });
 
-        mediaPlayer.setOnPreparedListener(new MediaPlayer.OnPreparedListener() {
+        mMediaPlayer.setOnPreparedListener(new MediaPlayer.OnPreparedListener() {
             @Override
             public void onPrepared(MediaPlayer mp) {
-                assertTrue(mp == mediaPlayer);
-                onPrepareCalled.open();
+                assertTrue(mp == mMediaPlayer);
+                mOnPrepareCalled.open();
             }
         });
 
-        mediaPlayer.setOnCompletionListener(new MediaPlayer.OnCompletionListener() {
+        mMediaPlayer.setOnCompletionListener(new MediaPlayer.OnCompletionListener() {
             @Override
             public void onCompletion(MediaPlayer mp) {
-                assertTrue(mp == mediaPlayer);
-                onCompletionCalled.open();
+                assertTrue(mp == mMediaPlayer);
+                mOnCompletionCalled.open();
             }
         });
+    }
 
+    @Override
+    protected void tearDown() throws Exception {
+        super.tearDown();
+        File flFile = new File(mFlFilePath);
+        if (flFile.exists()) {
+            flFile.delete();
+        }
+    }
+
+    public void testInvalidMidiNullPointerAccess() throws Exception {
+        testIfMediaServerDied(R.raw.midi_crash);
+    }
+
+    private void testIfMediaServerDied(int res) throws Exception {
         AssetFileDescriptor afd = getContext().getResources().openRawResourceFd(res);
-        mediaPlayer.setDataSource(afd.getFileDescriptor(), afd.getStartOffset(), afd.getLength());
+        mMediaPlayer.setDataSource(afd.getFileDescriptor(), afd.getStartOffset(), afd.getLength());
         afd.close();
         try {
-            mediaPlayer.prepareAsync();
-            if (!onPrepareCalled.block(5000)) {
+            mMediaPlayer.prepareAsync();
+            if (!mOnPrepareCalled.block(5000)) {
                 Log.w(TAG, "testIfMediaServerDied: Timed out waiting for prepare");
                 return;
             }
-            mediaPlayer.start();
-            if (!onCompletionCalled.block(5000)) {
+            mMediaPlayer.start();
+            if (!mOnCompletionCalled.block(5000)) {
                 Log.w(TAG, "testIfMediaServerDied: Timed out waiting for Error/Completion");
             }
         } catch (Exception e) {
             Log.w(TAG, "playback failed", e);
         } finally {
-            mediaPlayer.release();
+            mMediaPlayer.release();
         }
     }
+
+    public void testDrmManagerClientReset() throws Exception {
+        checkIfMediaServerDiedForDrm(R.raw.drm_uaf);
+    }
+
+    private void checkIfMediaServerDiedForDrm(int res) throws Exception {
+        if (!convertDmToFl(res, mFlFilePath)) {
+            fail("Can not convert dm to fl");
+        }
+        Log.d(TAG, "intermediate fl file is " + mFlFilePath);
+
+        ParcelFileDescriptor flFd = null;
+        try {
+            flFd = ParcelFileDescriptor.open(new File(mFlFilePath),
+                    ParcelFileDescriptor.MODE_READ_ONLY);
+        } catch (FileNotFoundException e) {
+            fail("Could not find file: " + mFlFilePath +  e);
+        }
+
+        mMediaPlayer.setDataSource(flFd.getFileDescriptor(), 0, flFd.getStatSize());
+        flFd.close();
+        try {
+            mMediaPlayer.prepare();
+        } catch (Exception e) {
+            Log.d(TAG, "Prepare failed", e);
+        }
+
+        try {
+            mMediaPlayer.reset();
+            if (!mOnCompletionCalled.block(5000)) {
+                Log.w(TAG, "checkIfMediaServerDiedForDrm: Timed out waiting for Error/Completion");
+            }
+        } catch (Exception e) {
+            fail("reset failed" + e);
+        } finally {
+            mMediaPlayer.release();
+        }
+    }
+
+    private boolean convertDmToFl(int res, String flFilePath) throws Exception {
+        AssetFileDescriptor afd = getContext().getResources().openRawResourceFd(res);
+        FileInputStream inputStream = afd.createInputStream();
+        int inputLength = (int)afd.getLength();
+        byte[] fileData = new byte[inputLength];
+        int readSize = inputStream.read(fileData, 0, inputLength);
+        assertEquals("can not pull in all data", readSize, inputLength);
+        inputStream.close();
+        afd.close();
+
+        FileOutputStream flStream = new FileOutputStream(new File(flFilePath));
+
+        DrmManagerClient drmClient = null;
+        try {
+            drmClient = new DrmManagerClient(mContext);
+        } catch (IllegalArgumentException e) {
+            Log.w(TAG, "DrmManagerClient instance could not be created, context is Illegal.");
+            return false;
+        } catch (IllegalStateException e) {
+            Log.w(TAG, "DrmManagerClient didn't initialize properly.");
+            return false;
+        }
+
+        if (drmClient == null) {
+            Log.w(TAG, "Failed to create DrmManagerClient.");
+            return false;
+        }
+
+        int convertSessionId = -1;
+        try {
+            convertSessionId = drmClient.openConvertSession(MIMETYPE_DRM_MESSAGE);
+        } catch (IllegalArgumentException e) {
+            Log.w(TAG, "Conversion of Mimetype: " + MIMETYPE_DRM_MESSAGE
+                    + " is not supported.", e);
+            return false;
+        } catch (IllegalStateException e) {
+            Log.w(TAG, "Could not access Open DrmFramework.", e);
+            return false;
+        }
+
+        if (convertSessionId < 0) {
+            Log.w(TAG, "Failed to open session.");
+            return false;
+        }
+
+        DrmConvertedStatus convertedStatus = null;
+        try {
+            convertedStatus = drmClient.convertData(convertSessionId, fileData);
+        } catch (IllegalArgumentException e) {
+            Log.w(TAG, "Buffer with data to convert is illegal. Convertsession: "
+                    + convertSessionId, e);
+            return false;
+        } catch (IllegalStateException e) {
+            Log.w(TAG, "Could not convert data. Convertsession: " + convertSessionId, e);
+            return false;
+        }
+
+        if (convertedStatus == null ||
+                convertedStatus.statusCode != DrmConvertedStatus.STATUS_OK ||
+                convertedStatus.convertedData == null) {
+            Log.w(TAG, "Error in converting data. Convertsession: " + convertSessionId);
+            try {
+                drmClient.closeConvertSession(convertSessionId);
+            } catch (IllegalStateException e) {
+                Log.w(TAG, "Could not close session. Convertsession: " +
+                       convertSessionId, e);
+            }
+            return false;
+        }
+
+        flStream.write(convertedStatus.convertedData, 0, convertedStatus.convertedData.length);
+        flStream.close();
+
+        try {
+            convertedStatus = drmClient.closeConvertSession(convertSessionId);
+        } catch (IllegalStateException e) {
+            Log.w(TAG, "Could not close convertsession. Convertsession: " +
+                    convertSessionId, e);
+            return false;
+        }
+
+        if (convertedStatus == null ||
+                convertedStatus.statusCode != DrmConvertedStatus.STATUS_OK ||
+                convertedStatus.convertedData == null) {
+            Log.w(TAG, "Error in closing session. Convertsession: " + convertSessionId);
+            return false;
+        }
+
+        RandomAccessFile flRandomAccessFile = null;
+        try {
+            flRandomAccessFile = new RandomAccessFile(flFilePath, "rw");
+            flRandomAccessFile.seek(convertedStatus.offset);
+            flRandomAccessFile.write(convertedStatus.convertedData);
+        } catch (FileNotFoundException e) {
+            Log.w(TAG, "File: " + flFilePath + " could not be found.", e);
+            return false;
+        } catch (IOException e) {
+            Log.w(TAG, "Could not access File: " + flFilePath + " .", e);
+            return false;
+        } catch (IllegalArgumentException e) {
+            Log.w(TAG, "Could not open file in mode: rw", e);
+            return false;
+        } catch (SecurityException e) {
+            Log.w(TAG, "Access to File: " + flFilePath +
+                    " was denied denied by SecurityManager.", e);
+            return false;
+        } finally {
+            if (flRandomAccessFile != null) {
+                try {
+                    flRandomAccessFile.close();
+                } catch (IOException e) {
+                    Log.w(TAG, "Failed to close File:" + flFilePath + ".", e);
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    }
 }