mediav2 CTS: Update tests to validate 10-bit support in media encoders

Bug: 210275761
Test: atest android.mediav2.cts.CodecEncoderValidationTest

Change-Id: Ib99ed17f30b3c38536f9f123d7ad1915b0f84cc4
diff --git a/tests/media/src/android/mediav2/cts/CodecEncoderValidationTest.java b/tests/media/src/android/mediav2/cts/CodecEncoderValidationTest.java
index f42f373..a0d90c7 100644
--- a/tests/media/src/android/mediav2/cts/CodecEncoderValidationTest.java
+++ b/tests/media/src/android/mediav2/cts/CodecEncoderValidationTest.java
@@ -22,6 +22,7 @@
 
 import androidx.test.filters.LargeTest;
 
+import org.junit.Assume;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -35,12 +36,15 @@
 import java.util.List;
 import java.util.Map;
 
+import static android.media.MediaCodecInfo.CodecCapabilities.COLOR_FormatYUVP010;
 import static android.mediav2.cts.CodecTestBase.SupportClass.*;
 import static org.junit.Assert.*;
 
 @RunWith(Parameterized.class)
 public class CodecEncoderValidationTest extends CodecEncoderTestBase {
     private static final String INPUT_AUDIO_FILE_HBD = "audio/sd_2ch_48kHz_f32le.raw";
+    private static final String INPUT_VIDEO_FILE_HBD = "dpov_352x288_30fps_yuv420p16le.yuv";
+
     private final boolean mUseHBD;
     private final SupportClass mSupportRequirements;
     // Key: mediaType, Value: tolerance duration in ms
@@ -69,7 +73,8 @@
         final boolean needVideo = true;
         final List<Object[]> defArgsList = Arrays.asList(new Object[][]{
                 // Audio tests covering cdd sec 5.1.3
-                // mediaType, arrays of bit-rates, sample rates, channel counts, usefloat
+                // mediaType, arrays of bit-rates, sample rates, channel counts, useHBD,
+                // SupportClass
                 {MediaFormat.MIMETYPE_AUDIO_AAC, new int[]{64000, 128000}, new int[]{8000, 12000,
                         16000, 22050, 24000, 32000, 44100, 48000}, new int[]{1, 2}, false,
                         CODEC_ALL},
@@ -88,7 +93,7 @@
                         new int[]{8000, 16000, 32000, 48000, 96000, 192000}, new int[]{1, 2},
                         true, CODEC_ALL},
 
-                // mediaType, arrays of bit-rates, width, height, Invalid Arg
+                // mediaType, arrays of bit-rates, width, height, useHBD, SupportClass
                 {MediaFormat.MIMETYPE_VIDEO_H263, new int[]{32000, 64000}, new int[]{176},
                         new int[]{144}, false, CODEC_ALL},
                 {MediaFormat.MIMETYPE_VIDEO_MPEG4, new int[]{32000, 64000}, new int[]{176},
@@ -108,6 +113,10 @@
     }
 
     void encodeAndValidate(String inputFile) throws IOException, InterruptedException {
+        if (!mIsAudio) {
+            int colorFormat = mFormats.get(0).getInteger(MediaFormat.KEY_COLOR_FORMAT);
+            Assume.assumeTrue(hasSupportForColorFormat(mCodecName, mMime, colorFormat));
+        }
         checkFormatSupport(mCodecName, mMime, true, mFormats, null, mSupportRequirements);
         setUpSource(inputFile);
         mOutputBuff = new OutputManager();
@@ -199,12 +208,22 @@
     @Test(timeout = PER_TEST_TIMEOUT_LARGE_TEST_MS)
     public void testEncodeAndValidate() throws IOException, InterruptedException {
         setUpParams(Integer.MAX_VALUE);
-        if (mIsAudio && mUseHBD) {
-            for (MediaFormat format : mFormats) {
-                format.setInteger(MediaFormat.KEY_PCM_ENCODING, AudioFormat.ENCODING_PCM_FLOAT);
+        String inputFile = mInputFile;
+        if (mUseHBD) {
+            if (mIsAudio) {
+                for (MediaFormat format : mFormats) {
+                    format.setInteger(MediaFormat.KEY_PCM_ENCODING, AudioFormat.ENCODING_PCM_FLOAT);
+                }
+                mBytesPerSample = 4;
+                inputFile = INPUT_AUDIO_FILE_HBD;
+            } else {
+                for (MediaFormat format : mFormats) {
+                    format.setInteger(MediaFormat.KEY_COLOR_FORMAT, COLOR_FormatYUVP010);
+                }
+                mBytesPerSample = 2;
+                inputFile = INPUT_VIDEO_FILE_HBD;
             }
-            mBytesPerSample = 4;
         }
-        encodeAndValidate(mIsAudio && mUseHBD ? INPUT_AUDIO_FILE_HBD : mInputFile);
+        encodeAndValidate(inputFile);
     }
 }
diff --git a/tests/media/src/android/mediav2/cts/CodecTestBase.java b/tests/media/src/android/mediav2/cts/CodecTestBase.java
index 67e25a5..2aaea0b 100644
--- a/tests/media/src/android/mediav2/cts/CodecTestBase.java
+++ b/tests/media/src/android/mediav2/cts/CodecTestBase.java
@@ -788,6 +788,22 @@
         return isSupported;
     }
 
+    static boolean hasSupportForColorFormat(String name, String mime, int colorFormat)
+            throws IOException {
+        MediaCodec codec = MediaCodec.createByCodecName(name);
+        MediaCodecInfo.CodecCapabilities cap =
+                codec.getCodecInfo().getCapabilitiesForType(mime);
+        boolean hasSupport = false;
+        for (int c : cap.colorFormats) {
+            if (c == colorFormat) {
+                hasSupport = true;
+                break;
+            }
+        }
+        codec.release();
+        return hasSupport;
+    }
+
     static boolean isDefaultCodec(String codecName, String mime, boolean isEncoder)
             throws IOException {
         Map<String,String> mDefaultCodecs = isEncoder ? mDefaultEncoders:  mDefaultDecoders;
@@ -1623,9 +1639,9 @@
         mMaxBFrames = 0;
         mChannels = 1;
         mSampleRate = 8000;
-        mBytesPerSample = 2;
         mAsyncHandle = new CodecAsyncHandler();
         mIsAudio = mMime.startsWith("audio/");
+        mBytesPerSample = mIsAudio ? 2 : 1;
         mInputFile = mIsAudio ? INPUT_AUDIO_FILE : INPUT_VIDEO_FILE;
     }
 
@@ -1680,7 +1696,12 @@
     }
 
     void fillImage(Image image) {
-        Assert.assertTrue(image.getFormat() == ImageFormat.YUV_420_888);
+        int format = image.getFormat();
+        assertTrue("unexpected image format",
+                format == ImageFormat.YUV_420_888 || format == ImageFormat.YCBCR_P010);
+        int bytesPerSample = (ImageFormat.getBitsPerPixel(format) * 2) / (8 * 3);  // YUV420
+        assertEquals("Invalid bytes per sample", bytesPerSample, mBytesPerSample);
+
         int imageWidth = image.getWidth();
         int imageHeight = image.getHeight();
         Image.Plane[] planes = image.getPlanes();
@@ -1699,17 +1720,18 @@
                 tileWidth = INP_FRM_WIDTH / 2;
                 tileHeight = INP_FRM_HEIGHT / 2;
             }
-            if (pixelStride == 1) {
+            if (pixelStride == bytesPerSample) {
                 if (width == rowStride && width == tileWidth && height == tileHeight) {
-                    buf.put(mInputData, offset, width * height);
+                    buf.put(mInputData, offset, width * height * bytesPerSample);
                 } else {
                     for (int z = 0; z < height; z += tileHeight) {
                         int rowsToCopy = Math.min(height - z, tileHeight);
                         for (int y = 0; y < rowsToCopy; y++) {
                             for (int x = 0; x < width; x += tileWidth) {
                                 int colsToCopy = Math.min(width - x, tileWidth);
-                                buf.position((z + y) * rowStride + x);
-                                buf.put(mInputData, offset + y * tileWidth, colsToCopy);
+                                buf.position((z + y) * rowStride + x * bytesPerSample);
+                                buf.put(mInputData, offset + y * tileWidth * bytesPerSample,
+                                        colsToCopy * bytesPerSample);
                             }
                         }
                     }
@@ -1723,14 +1745,17 @@
                         for (int x = 0; x < width; x += tileWidth) {
                             int colsToCopy = Math.min(width - x, tileWidth);
                             for (int w = 0; w < colsToCopy; w++) {
-                                buf.position(lineOffset + (x + w) * pixelStride);
-                                buf.put(mInputData[offset + y * tileWidth + w]);
+                                for (int bytePos = 0; bytePos < bytesPerSample; bytePos++) {
+                                    buf.position(lineOffset + (x + w) * pixelStride + bytePos);
+                                    buf.put(mInputData[offset + y * tileWidth * bytesPerSample +
+                                            w * bytesPerSample + bytePos]);
+                                }
                             }
                         }
                     }
                 }
             }
-            offset += tileWidth * tileHeight;
+            offset += tileWidth * tileHeight * bytesPerSample;
         }
     }
 
@@ -1752,13 +1777,15 @@
                 for (int j = 0; j < rowsToCopy; j++) {
                     for (int i = 0; i < width; i += tileWidth) {
                         int colsToCopy = Math.min(width - i, tileWidth);
-                        inputBuffer.position(offset + (k + j) * width + i);
-                        inputBuffer.put(mInputData, frmOffset + j * tileWidth, colsToCopy);
+                        inputBuffer.position(
+                                offset + (k + j) * width * mBytesPerSample + i * mBytesPerSample);
+                        inputBuffer.put(mInputData, frmOffset + j * tileWidth * mBytesPerSample,
+                                colsToCopy * mBytesPerSample);
                     }
                 }
             }
-            offset += width * height;
-            frmOffset += tileWidth * tileHeight;
+            offset += width * height * mBytesPerSample;
+            frmOffset += tileWidth * tileHeight * mBytesPerSample;
         }
     }
 
@@ -1782,8 +1809,8 @@
                 mNumBytesSubmitted += size;
             } else {
                 pts += mInputCount * 1000000L / mFrameRate;
-                size = mWidth * mHeight * 3 / 2;
-                int frmSize = INP_FRM_WIDTH * INP_FRM_HEIGHT * 3 / 2;
+                size = mBytesPerSample * mWidth * mHeight * 3 / 2;
+                int frmSize = mBytesPerSample * INP_FRM_WIDTH * INP_FRM_HEIGHT * 3 / 2;
                 if (mNumBytesSubmitted + frmSize > mInputData.length) {
                     fail("received partial frame to encode");
                 } else {