blob: f6b31767c37ab036e5daea2499e68ad8b92b53fb [file] [log] [blame]
/*
* Copyright (C) 2015 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.renderscript.cts;
import android.renderscript.*;
import android.util.Log;
import java.util.ArrayList;
public class IntrinsicBLAS extends IntrinsicBase {
private ScriptIntrinsicBLAS mBLAS;
// Per-precision fixtures: the candidate allocations built in setUp() plus the alpha/beta
// scalars passed to every call. S = F32/float, D = F64/double, C = F32_2/Float2,
// Z = F64_2/Double2 (the two-component types are used for the complex C*/Z* routines).
private ArrayList<Allocation> mMatrixS;
private final float alphaS = 1.0f;
private final float betaS = 1.0f;
private ArrayList<Allocation> mMatrixD;
private final double alphaD = 1.0;
private final double betaD = 1.0;
private ArrayList<Allocation> mMatrixC;
private final Float2 alphaC = new Float2(1.0f, 1.0f);
private final Float2 betaC = new Float2(1.0f, 1.0f);
private ArrayList<Allocation> mMatrixZ;
private final Double2 alphaZ = new Double2(1.0, 1.0);
private final Double2 betaZ = new Double2(1.0, 1.0);
// Argument values swept by the L2_x*_API drivers. Each flag table ends with a literal 0.
// The validate* helpers accept only the named constants, so 0 is presumably an illegal
// value that drives the "must throw" path (TODO confirm 0 never equals a named constant).
private int[] mTranspose = {
        ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.TRANSPOSE,
        ScriptIntrinsicBLAS.CONJ_TRANSPOSE, 0};
private int[] mUplo = {
        ScriptIntrinsicBLAS.UPPER, ScriptIntrinsicBLAS.LOWER, 0};
private int[] mDiag = {
        ScriptIntrinsicBLAS.NON_UNIT, ScriptIntrinsicBLAS.UNIT, 0};
private int[] mSide = {
        ScriptIntrinsicBLAS.LEFT, ScriptIntrinsicBLAS.RIGHT, 0};
// Vector increments; the validators require inc > 0, so 0 exercises rejection.
private int[] mInc = {0, 1, 2};
// Band widths (K, KL, KU); the band tests require K >= 0, so -1 exercises rejection.
private int[] mK = {-1, 0, 1};
// Extents used by setUp() to build one allocation for every (x, y) pair.
private int[] mDim = {1, 2, 3, 256};
@Override
protected void setUp() throws Exception {
    super.setUp();
    mBLAS = ScriptIntrinsicBLAS.create(mRS);
    // Candidate matrices/vectors: for every (x, y) pair drawn from mDim, one 2-D
    // allocation per element type (F32, F64, F32_2, F64_2), created in that order.
    mMatrixS = new ArrayList<Allocation>();
    mMatrixD = new ArrayList<Allocation>();
    mMatrixC = new ArrayList<Allocation>();
    mMatrixZ = new ArrayList<Allocation>();
    for (int i = 0; i < mDim.length; i++) {
        final int dimX = mDim[i];
        for (int j = 0; j < mDim.length; j++) {
            final int dimY = mDim[j];
            mMatrixS.add(Allocation.createTyped(mRS, Type.createXY(mRS, Element.F32(mRS), dimX, dimY)));
            mMatrixD.add(Allocation.createTyped(mRS, Type.createXY(mRS, Element.F64(mRS), dimX, dimY)));
            mMatrixC.add(Allocation.createTyped(mRS, Type.createXY(mRS, Element.F32_2(mRS), dimX, dimY)));
            mMatrixZ.add(Allocation.createTyped(mRS, Type.createXY(mRS, Element.F64_2(mRS), dimX, dimY)));
        }
    }
    // One 1x1 U8 allocation, presumably incompatible with every element type above, so
    // element-mismatch rejection is exercised. The same object is appended to all four lists.
    final Allocation mismatched =
            Allocation.createTyped(mRS, Type.createXY(mRS, Element.U8(mRS), 1, 1));
    mMatrixS.add(mismatched);
    mMatrixD.add(mismatched);
    mMatrixC.add(mismatched);
    mMatrixZ.add(mismatched);
}
@Override
protected void tearDown() throws Exception {
    // Release the BLAS intrinsic created in setUp(), if any, before the base class
    // performs its own teardown.
    if (this.mBLAS != null) {
        this.mBLAS.destroy();
        this.mBLAS = null;
    }
    super.tearDown();
}
/** Returns true iff {@code Side} is ScriptIntrinsicBLAS.LEFT or ScriptIntrinsicBLAS.RIGHT. */
private boolean validateSide(int Side) {
    return Side == ScriptIntrinsicBLAS.LEFT || Side == ScriptIntrinsicBLAS.RIGHT;
}
/** Returns true iff {@code Trans} is NO_TRANSPOSE, TRANSPOSE or CONJ_TRANSPOSE. */
private boolean validateTranspose(int Trans) {
    return Trans == ScriptIntrinsicBLAS.NO_TRANSPOSE
            || Trans == ScriptIntrinsicBLAS.TRANSPOSE
            || Trans == ScriptIntrinsicBLAS.CONJ_TRANSPOSE;
}
/** Returns true iff {@code Trans} is NO_TRANSPOSE or CONJ_TRANSPOSE (plain TRANSPOSE is rejected). */
private boolean validateConjTranspose(int Trans) {
    return Trans == ScriptIntrinsicBLAS.NO_TRANSPOSE
            || Trans == ScriptIntrinsicBLAS.CONJ_TRANSPOSE;
}
/** Returns true iff {@code Diag} is ScriptIntrinsicBLAS.NON_UNIT or ScriptIntrinsicBLAS.UNIT. */
private boolean validateDiag(int Diag) {
    return Diag == ScriptIntrinsicBLAS.NON_UNIT || Diag == ScriptIntrinsicBLAS.UNIT;
}
/** Returns true iff {@code Uplo} is ScriptIntrinsicBLAS.UPPER or ScriptIntrinsicBLAS.LOWER. */
private boolean validateUplo(int Uplo) {
    return Uplo == ScriptIntrinsicBLAS.UPPER || Uplo == ScriptIntrinsicBLAS.LOWER;
}
/**
 * Encodes which GEMV-style argument combinations the test expects to be accepted
 * (also reused for GBMV).
 *
 * Returns true iff: TransA is a legal transpose flag; A, X and Y all have elements
 * compatible with {@code e}; X and Y are one-dimensional; both increments are positive;
 * and the vector lengths match A. With M = A's Y extent and N = A's X extent,
 * NO_TRANSPOSE needs |X| = 1 + (N-1)*incX and |Y| = 1 + (M-1)*incY. Any other legal
 * flag swaps the roles of M and N.
 */
private boolean validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
    if (!validateTranspose(TransA)) {
        return false;
    }
    final int rows = A.getType().getY();
    final int cols = A.getType().getX();
    final boolean elementsMatch = A.getType().getElement().isCompatible(e)
            && X.getType().getElement().isCompatible(e)
            && Y.getType().getElement().isCompatible(e);
    if (!elementsMatch) {
        return false;
    }
    // X and Y must be vectors (no second dimension).
    if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
        return false;
    }
    if (incX <= 0 || incY <= 0) {
        return false;
    }
    final boolean untransposed = TransA == ScriptIntrinsicBLAS.NO_TRANSPOSE;
    final int lenX = untransposed ? cols : rows;
    final int lenY = untransposed ? rows : cols;
    return X.getType().getX() == 1 + (lenX - 1) * incX
            && Y.getType().getX() == 1 + (lenY - 1) * incY;
}
/**
 * Checks S/D/C/ZGEMV argument validation for one (trans, incX, incY) combination.
 *
 * Every (A, X, Y) triple drawn from {@code mMatrix} is tried. When validateGEMV accepts
 * it, the GEMV variant matching A's element type must complete without an
 * RSRuntimeException. Otherwise SGEMV, DGEMV, CGEMV and ZGEMV must each throw one.
 */
private void xGEMV_API_test(int trans, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (!validateGEMV(elem, trans, a, x, incX, y, incY)) {
                    // Rejected combination: every precision variant must refuse it.
                    try {
                        mBLAS.SGEMV(trans, alphaS, a, x, incX, betaS, y, incY);
                        fail("should throw RSRuntimeException for SGEMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.DGEMV(trans, alphaD, a, x, incX, betaD, y, incY);
                        fail("should throw RSRuntimeException for DGEMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.CGEMV(trans, alphaC, a, x, incX, betaC, y, incY);
                        fail("should throw RSRuntimeException for CGEMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.ZGEMV(trans, alphaZ, a, x, incX, betaZ, y, incY);
                        fail("should throw RSRuntimeException for ZGEMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                // Accepted combination: call only the variant matching A's element type.
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SGEMV(trans, alphaS, a, x, incX, betaS, y, incY);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DGEMV(trans, alphaD, a, x, incX, betaD, y, incY);
                    } else if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CGEMV(trans, alphaC, a, x, incX, betaC, y, incY);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZGEMV(trans, alphaZ, a, x, incX, betaZ, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/**
 * Runs xGEMV_API_test on {@code mMatrix} for every transpose flag in mTranspose and
 * every (incX, incY) pair drawn from mInc.
 */
public void L2_xGEMV_API(ArrayList<Allocation> mMatrix) {
    for (int t = 0; t < mTranspose.length; t++) {
        for (int i = 0; i < mInc.length; i++) {
            for (int j = 0; j < mInc.length; j++) {
                xGEMV_API_test(mTranspose[t], mInc[i], mInc[j], mMatrix);
            }
        }
    }
}
/** GEMV API checks over the F32 fixtures. */
public void test_L2_SGEMV_API() {
    L2_xGEMV_API(this.mMatrixS);
}
/** GEMV API checks over the F64 fixtures. */
public void test_L2_DGEMV_API() {
    L2_xGEMV_API(this.mMatrixD);
}
/** GEMV API checks over the F32_2 fixtures. */
public void test_L2_CGEMV_API() {
    L2_xGEMV_API(this.mMatrixC);
}
/** GEMV API checks over the F64_2 fixtures. */
public void test_L2_ZGEMV_API() {
    L2_xGEMV_API(this.mMatrixZ);
}
/**
 * Checks S/D/C/ZGBMV argument validation for one (trans, KL, KU, incX, incY) combination.
 *
 * A triple (A, X, Y) from {@code mMatrix} counts as valid when validateGEMV accepts it and
 * both band widths are non-negative. Valid triples must run the variant matching A's
 * element type without an RSRuntimeException. Invalid triples must make all four variants
 * throw one.
 */
private void xGBMV_API_test(int trans, int KL, int KU, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                boolean accepted = validateGEMV(elem, trans, a, x, incX, y, incY) && KU >= 0 && KL >= 0;
                if (!accepted) {
                    try {
                        mBLAS.SGBMV(trans, KL, KU, alphaS, a, x, incX, betaS, y, incY);
                        fail("should throw RSRuntimeException for SGBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.DGBMV(trans, KL, KU, alphaD, a, x, incX, betaD, y, incY);
                        fail("should throw RSRuntimeException for DGBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.CGBMV(trans, KL, KU, alphaC, a, x, incX, betaC, y, incY);
                        fail("should throw RSRuntimeException for CGBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.ZGBMV(trans, KL, KU, alphaZ, a, x, incX, betaZ, y, incY);
                        fail("should throw RSRuntimeException for ZGBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SGBMV(trans, KL, KU, alphaS, a, x, incX, betaS, y, incY);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DGBMV(trans, KL, KU, alphaD, a, x, incX, betaD, y, incY);
                    } else if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CGBMV(trans, KL, KU, alphaC, a, x, incX, betaC, y, incY);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZGBMV(trans, KL, KU, alphaZ, a, x, incX, betaZ, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/**
 * Runs xGBMV_API_test on {@code mMatrix} for every transpose flag, every (incX, incY) pair
 * from mInc and every band width K from mK. The same K is passed as both KL and KU.
 */
public void L2_xGBMV_API(ArrayList<Allocation> mMatrix) {
    for (int t = 0; t < mTranspose.length; t++) {
        for (int i = 0; i < mInc.length; i++) {
            for (int j = 0; j < mInc.length; j++) {
                for (int k = 0; k < mK.length; k++) {
                    final int band = mK[k];
                    xGBMV_API_test(mTranspose[t], band, band, mInc[i], mInc[j], mMatrix);
                }
            }
        }
    }
}
/** GBMV API checks over the F32 fixtures. */
public void test_L2_SGBMV_API() {
    L2_xGBMV_API(this.mMatrixS);
}
/** GBMV API checks over the F64 fixtures. */
public void test_L2_DGBMV_API() {
    L2_xGBMV_API(this.mMatrixD);
}
/** GBMV API checks over the F32_2 fixtures. */
public void test_L2_CGBMV_API() {
    L2_xGBMV_API(this.mMatrixC);
}
/** GBMV API checks over the F64_2 fixtures. */
public void test_L2_ZGBMV_API() {
    L2_xGBMV_API(this.mMatrixZ);
}
/**
 * Checks C/ZHEMV argument validation for one (Uplo, incX, incY) combination.
 *
 * Validity of each (A, X, Y) triple from {@code mMatrix} is decided by the validateSYR2
 * helper. When the triple is accepted, CHEMV (F32_2) or ZHEMV (F64_2) is called according
 * to A's element type and must not throw RSRuntimeException; other element types make no
 * call. When it is rejected, both CHEMV and ZHEMV must throw RSRuntimeException.
 */
private void xHEMV_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (!validateSYR2(elem, Uplo, x, incX, y, incY, a)) {
                    try {
                        mBLAS.CHEMV(Uplo, alphaC, a, x, incX, betaC, y, incY);
                        fail("should throw RSRuntimeException for CHEMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.ZHEMV(Uplo, alphaZ, a, x, incX, betaZ, y, incY);
                        fail("should throw RSRuntimeException for ZHEMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CHEMV(Uplo, alphaC, a, x, incX, betaC, y, incY);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZHEMV(Uplo, alphaZ, a, x, incX, betaZ, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/** Runs xHEMV_API_test on {@code mMatrix} for every Uplo flag and every (incX, incY) pair from mInc. */
public void L2_xHEMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            for (int j = 0; j < mInc.length; j++) {
                xHEMV_API_test(mUplo[u], mInc[i], mInc[j], mMatrix);
            }
        }
    }
}
/** HEMV API checks over the F32_2 fixtures. */
public void test_L2_CHEMV_API() {
    L2_xHEMV_API(this.mMatrixC);
}
/** HEMV API checks over the F64_2 fixtures. */
public void test_L2_ZHEMV_API() {
    L2_xHEMV_API(this.mMatrixZ);
}
/**
 * Checks C/ZHBMV argument validation for one (Uplo, K, incX, incY) combination.
 *
 * A triple (A, X, Y) from {@code mMatrix} is valid when the validateSYR2 helper accepts it
 * and K is non-negative. Valid triples run CHBMV (F32_2) or ZHBMV (F64_2) according to A's
 * element type and must not throw RSRuntimeException. Invalid triples must make both
 * variants throw.
 */
private void xHBMV_API_test(int Uplo, int K, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                boolean accepted = validateSYR2(elem, Uplo, x, incX, y, incY, a) && K >= 0;
                if (!accepted) {
                    try {
                        mBLAS.CHBMV(Uplo, K, alphaC, a, x, incX, betaC, y, incY);
                        fail("should throw RSRuntimeException for CHBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.ZHBMV(Uplo, K, alphaZ, a, x, incX, betaZ, y, incY);
                        fail("should throw RSRuntimeException for ZHBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CHBMV(Uplo, K, alphaC, a, x, incX, betaC, y, incY);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZHBMV(Uplo, K, alphaZ, a, x, incX, betaZ, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/** Runs xHBMV_API_test on {@code mMatrix} for every Uplo flag, band width K and (incX, incY) pair. */
public void L2_xHBMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int k = 0; k < mK.length; k++) {
            for (int i = 0; i < mInc.length; i++) {
                for (int j = 0; j < mInc.length; j++) {
                    xHBMV_API_test(mUplo[u], mK[k], mInc[i], mInc[j], mMatrix);
                }
            }
        }
    }
}
/** HBMV API checks over the F32_2 fixtures. */
public void test_L2_CHBMV_API() {
    L2_xHBMV_API(this.mMatrixC);
}
/** HBMV API checks over the F64_2 fixtures. */
public void test_L2_ZHBMV_API() {
    L2_xHBMV_API(this.mMatrixZ);
}
/**
 * Checks C/ZHPMV argument validation for one (Uplo, incX, incY) combination.
 *
 * Validity of each (Ap, X, Y) triple from {@code mMatrix} is decided by the validateSPR2
 * helper. Accepted triples run CHPMV (F32_2) or ZHPMV (F64_2) according to the matrix
 * element type and must not throw RSRuntimeException. Rejected triples must make both
 * variants throw.
 */
private void xHPMV_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (!validateSPR2(elem, Uplo, x, incX, y, incY, a)) {
                    try {
                        mBLAS.CHPMV(Uplo, alphaC, a, x, incX, betaC, y, incY);
                        fail("should throw RSRuntimeException for CHPMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.ZHPMV(Uplo, alphaZ, a, x, incX, betaZ, y, incY);
                        fail("should throw RSRuntimeException for ZHPMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CHPMV(Uplo, alphaC, a, x, incX, betaC, y, incY);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZHPMV(Uplo, alphaZ, a, x, incX, betaZ, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/** Runs xHPMV_API_test on {@code mMatrix} for every Uplo flag and every (incX, incY) pair from mInc. */
public void L2_xHPMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            for (int j = 0; j < mInc.length; j++) {
                xHPMV_API_test(mUplo[u], mInc[i], mInc[j], mMatrix);
            }
        }
    }
}
/** HPMV API checks over the F32_2 fixtures. */
public void test_L2_CHPMV_API() {
    L2_xHPMV_API(this.mMatrixC);
}
/** HPMV API checks over the F64_2 fixtures. */
public void test_L2_ZHPMV_API() {
    L2_xHPMV_API(this.mMatrixZ);
}
/**
 * Encodes which SYMV-style argument combinations the test expects to be accepted
 * (used by xSYMV_API_test and xSBMV_API_test).
 *
 * Returns true iff: Uplo is legal; A is square (N x N); A, X and Y all have elements
 * compatible with {@code e}; X and Y are one-dimensional; both increments are positive;
 * and |X| = 1 + (N-1)*incX and |Y| = 1 + (N-1)*incY.
 */
private boolean validateSYMV(Element e, int Uplo, Allocation A, Allocation X, int incX, Allocation Y, int incY) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    final int n = A.getType().getY();
    final boolean square = A.getType().getX() == n;
    if (!square) {
        return false;
    }
    final boolean elementsMatch = A.getType().getElement().isCompatible(e)
            && X.getType().getElement().isCompatible(e)
            && Y.getType().getElement().isCompatible(e);
    final boolean vectorsAre1D = X.getType().getY() <= 1 && Y.getType().getY() <= 1;
    final boolean stridesPositive = incX > 0 && incY > 0;
    if (!elementsMatch || !vectorsAre1D || !stridesPositive) {
        return false;
    }
    return X.getType().getX() == 1 + (n - 1) * incX
            && Y.getType().getX() == 1 + (n - 1) * incY;
}
/**
 * Checks S/DSYMV argument validation for one (Uplo, incX, incY) combination.
 *
 * Triples (A, X, Y) from {@code mMatrix} accepted by validateSYMV run SSYMV (F32) or
 * DSYMV (F64) according to A's element type and must not throw RSRuntimeException.
 * Rejected triples must make both variants throw.
 */
private void xSYMV_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (!validateSYMV(elem, Uplo, a, x, incX, y, incY)) {
                    try {
                        mBLAS.SSYMV(Uplo, alphaS, a, x, incX, betaS, y, incY);
                        fail("should throw RSRuntimeException for SSYMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.DSYMV(Uplo, alphaD, a, x, incX, betaD, y, incY);
                        fail("should throw RSRuntimeException for DSYMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SSYMV(Uplo, alphaS, a, x, incX, betaS, y, incY);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DSYMV(Uplo, alphaD, a, x, incX, betaD, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/** Runs xSYMV_API_test on {@code mMatrix} for every Uplo flag and every (incX, incY) pair from mInc. */
public void L2_xSYMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            for (int j = 0; j < mInc.length; j++) {
                xSYMV_API_test(mUplo[u], mInc[i], mInc[j], mMatrix);
            }
        }
    }
}
/** SYMV API checks over the F32 fixtures. */
public void test_L2_SSYMV_API() {
    L2_xSYMV_API(this.mMatrixS);
}
/** SYMV API checks over the F64 fixtures. */
public void test_L2_DSYMV_API() {
    L2_xSYMV_API(this.mMatrixD);
}
/**
 * Checks S/DSBMV argument validation for one (Uplo, K, incX, incY) combination.
 *
 * A triple (A, X, Y) from {@code mMatrix} is valid when validateSYMV accepts it and K is
 * non-negative. Valid triples run SSBMV (F32) or DSBMV (F64) according to A's element type
 * and must not throw RSRuntimeException. Invalid triples must make both variants throw.
 */
private void xSBMV_API_test(int Uplo, int K, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                boolean accepted = validateSYMV(elem, Uplo, a, x, incX, y, incY) && K >= 0;
                if (!accepted) {
                    try {
                        mBLAS.SSBMV(Uplo, K, alphaS, a, x, incX, betaS, y, incY);
                        fail("should throw RSRuntimeException for SSBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.DSBMV(Uplo, K, alphaD, a, x, incX, betaD, y, incY);
                        fail("should throw RSRuntimeException for DSBMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SSBMV(Uplo, K, alphaS, a, x, incX, betaS, y, incY);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DSBMV(Uplo, K, alphaD, a, x, incX, betaD, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/** Runs xSBMV_API_test on {@code mMatrix} for every Uplo flag, band width K and (incX, incY) pair. */
public void L2_xSBMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int k = 0; k < mK.length; k++) {
            for (int i = 0; i < mInc.length; i++) {
                for (int j = 0; j < mInc.length; j++) {
                    xSBMV_API_test(mUplo[u], mK[k], mInc[i], mInc[j], mMatrix);
                }
            }
        }
    }
}
/** SBMV API checks over the F32 fixtures. */
public void test_L2_SSBMV_API() {
    L2_xSBMV_API(this.mMatrixS);
}
/** SBMV API checks over the F64 fixtures. */
public void test_L2_DSBMV_API() {
    L2_xSBMV_API(this.mMatrixD);
}
/**
 * Encodes which packed-storage SPMV argument combinations the test expects to be accepted.
 *
 * Returns true iff: Uplo is legal; Ap, X and Y all have elements compatible with {@code e};
 * X, Y and Ap are all one-dimensional; Ap's length equals N*(N+1)/2 for the N recovered
 * below; both increments are positive; and |X| = 1 + (N-1)*incX, |Y| = 1 + (N-1)*incY.
 */
private boolean validateSPMV(Element e, int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    final boolean elementsMatch = Ap.getType().getElement().isCompatible(e)
            && X.getType().getElement().isCompatible(e)
            && Y.getType().getElement().isCompatible(e);
    if (!elementsMatch) {
        return false;
    }
    if (X.getType().getY() > 1 || Y.getType().getY() > 1 || Ap.getType().getY() > 1) {
        return false;
    }
    // floor(sqrt(2 * len)) recovers N when len is a triangular number N*(N+1)/2;
    // any other length fails the equality check.
    final int packedLen = Ap.getType().getX();
    final int n = (int) Math.sqrt((double) packedLen * 2);
    if (packedLen != (n * (n + 1)) / 2) {
        return false;
    }
    if (incX <= 0 || incY <= 0) {
        return false;
    }
    return X.getType().getX() == 1 + (n - 1) * incX
            && Y.getType().getX() == 1 + (n - 1) * incY;
}
/**
 * Checks S/DSPMV argument validation for one (Uplo, incX, incY) combination.
 *
 * Triples (Ap, X, Y) from {@code mMatrix} accepted by validateSPMV run SSPMV (F32) or
 * DSPMV (F64) according to the matrix element type and must not throw RSRuntimeException.
 * Rejected triples must make both variants throw.
 */
private void xSPMV_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (!validateSPMV(elem, Uplo, a, x, incX, y, incY)) {
                    try {
                        mBLAS.SSPMV(Uplo, alphaS, a, x, incX, betaS, y, incY);
                        fail("should throw RSRuntimeException for SSPMV");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.DSPMV(Uplo, alphaD, a, x, incX, betaD, y, incY);
                        fail("should throw RSRuntimeException for DSPMV");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SSPMV(Uplo, alphaS, a, x, incX, betaS, y, incY);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DSPMV(Uplo, alphaD, a, x, incX, betaD, y, incY);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/** Runs xSPMV_API_test on {@code mMatrix} for every Uplo flag and every (incX, incY) pair from mInc. */
public void L2_xSPMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            for (int j = 0; j < mInc.length; j++) {
                xSPMV_API_test(mUplo[u], mInc[i], mInc[j], mMatrix);
            }
        }
    }
}
/** SPMV API checks over the F32 fixtures. */
public void test_L2_SSPMV_API() {
    L2_xSPMV_API(this.mMatrixS);
}
/** SPMV API checks over the F64 fixtures. */
public void test_L2_DSPMV_API() {
    L2_xSPMV_API(this.mMatrixD);
}
/**
 * Encodes which triangular-matrix argument combinations the test expects to be accepted
 * (used by the TRMV, TBMV, TRSV and TBSV tests).
 *
 * Returns true iff Uplo, TransA and Diag are all legal; A is square (N x N); A and X have
 * elements compatible with {@code e}; X is one-dimensional; incX is positive; and
 * |X| = 1 + (N-1)*incX.
 */
private boolean validateTRMV(Element e, int Uplo, int TransA, int Diag, Allocation A, Allocation X, int incX) {
    final boolean flagsLegal = validateUplo(Uplo) && validateTranspose(TransA) && validateDiag(Diag);
    if (!flagsLegal) {
        return false;
    }
    final int n = A.getType().getY();
    if (A.getType().getX() != n) {
        return false;
    }
    if (!A.getType().getElement().isCompatible(e) || !X.getType().getElement().isCompatible(e)) {
        return false;
    }
    return X.getType().getY() <= 1
            && incX > 0
            && X.getType().getX() == 1 + (n - 1) * incX;
}
/**
 * Checks S/D/C/ZTRMV argument validation for one (Uplo, TransA, Diag, incX) combination.
 *
 * Pairs (A, X) from {@code mMatrix} accepted by validateTRMV run the variant matching A's
 * element type and must not throw RSRuntimeException. Rejected pairs must make all four
 * variants throw.
 */
private void xTRMV_API_test(int Uplo, int TransA, int Diag, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            if (!validateTRMV(elem, Uplo, TransA, Diag, a, x, incX)) {
                try {
                    mBLAS.STRMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for STRMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.DTRMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for DTRMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.CTRMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for CTRMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.ZTRMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for ZTRMV");
                } catch (RSRuntimeException expected) {
                }
                continue;
            }
            try {
                if (elem.isCompatible(Element.F32(mRS))) {
                    mBLAS.STRMV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64(mRS))) {
                    mBLAS.DTRMV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F32_2(mRS))) {
                    mBLAS.CTRMV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64_2(mRS))) {
                    mBLAS.ZTRMV(Uplo, TransA, Diag, a, x, incX);
                }
            } catch (RSRuntimeException unexpected) {
                fail("should NOT throw RSRuntimeException");
            }
        }
    }
}
/** Runs xTRMV_API_test on {@code mMatrix} for every Uplo, TransA, Diag and incX value. */
public void L2_xTRMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            for (int d = 0; d < mDiag.length; d++) {
                for (int i = 0; i < mInc.length; i++) {
                    xTRMV_API_test(mUplo[u], mTranspose[t], mDiag[d], mInc[i], mMatrix);
                }
            }
        }
    }
}
/** TRMV API checks over the F32 fixtures. */
public void test_L2_STRMV_API() {
    L2_xTRMV_API(this.mMatrixS);
}
/** TRMV API checks over the F64 fixtures. */
public void test_L2_DTRMV_API() {
    L2_xTRMV_API(this.mMatrixD);
}
/** TRMV API checks over the F32_2 fixtures. */
public void test_L2_CTRMV_API() {
    L2_xTRMV_API(this.mMatrixC);
}
/** TRMV API checks over the F64_2 fixtures. */
public void test_L2_ZTRMV_API() {
    L2_xTRMV_API(this.mMatrixZ);
}
/**
 * Checks S/D/C/ZTBMV argument validation for one (Uplo, TransA, Diag, K, incX) combination.
 *
 * A pair (A, X) from {@code mMatrix} is valid when validateTRMV accepts it and K is
 * non-negative. Valid pairs run the variant matching A's element type and must not throw
 * RSRuntimeException. Invalid pairs must make all four variants throw.
 */
private void xTBMV_API_test(int Uplo, int TransA, int Diag, int K, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            boolean accepted = validateTRMV(elem, Uplo, TransA, Diag, a, x, incX) && K >= 0;
            if (!accepted) {
                try {
                    mBLAS.STBMV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for STBMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.DTBMV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for DTBMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.CTBMV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for CTBMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.ZTBMV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for ZTBMV");
                } catch (RSRuntimeException expected) {
                }
                continue;
            }
            try {
                if (elem.isCompatible(Element.F32(mRS))) {
                    mBLAS.STBMV(Uplo, TransA, Diag, K, a, x, incX);
                } else if (elem.isCompatible(Element.F64(mRS))) {
                    mBLAS.DTBMV(Uplo, TransA, Diag, K, a, x, incX);
                } else if (elem.isCompatible(Element.F32_2(mRS))) {
                    mBLAS.CTBMV(Uplo, TransA, Diag, K, a, x, incX);
                } else if (elem.isCompatible(Element.F64_2(mRS))) {
                    mBLAS.ZTBMV(Uplo, TransA, Diag, K, a, x, incX);
                }
            } catch (RSRuntimeException unexpected) {
                fail("should NOT throw RSRuntimeException");
            }
        }
    }
}
/** Runs xTBMV_API_test on {@code mMatrix} for every Uplo, TransA, Diag, band width K and incX value. */
public void L2_xTBMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            for (int d = 0; d < mDiag.length; d++) {
                for (int k = 0; k < mK.length; k++) {
                    for (int i = 0; i < mInc.length; i++) {
                        xTBMV_API_test(mUplo[u], mTranspose[t], mDiag[d], mK[k], mInc[i], mMatrix);
                    }
                }
            }
        }
    }
}
/** TBMV API checks over the F32 fixtures. */
public void test_L2_STBMV_API() {
    L2_xTBMV_API(this.mMatrixS);
}
/** TBMV API checks over the F64 fixtures. */
public void test_L2_DTBMV_API() {
    L2_xTBMV_API(this.mMatrixD);
}
/** TBMV API checks over the F32_2 fixtures. */
public void test_L2_CTBMV_API() {
    L2_xTBMV_API(this.mMatrixC);
}
/** TBMV API checks over the F64_2 fixtures. */
public void test_L2_ZTBMV_API() {
    L2_xTBMV_API(this.mMatrixZ);
}
/**
 * Encodes which packed triangular-matrix argument combinations the test expects to be
 * accepted (used by the TPMV and TPSV tests).
 *
 * Returns true iff Uplo, TransA and Diag are legal; Ap and X have elements compatible with
 * {@code e}; X and Ap are one-dimensional; Ap's length equals N*(N+1)/2 for the N recovered
 * below; incX is positive; and |X| = 1 + (N-1)*incX.
 */
private boolean validateTPMV(Element e, int Uplo, int TransA, int Diag, Allocation Ap, Allocation X, int incX) {
    final boolean flagsLegal = validateUplo(Uplo) && validateTranspose(TransA) && validateDiag(Diag);
    if (!flagsLegal) {
        return false;
    }
    if (!Ap.getType().getElement().isCompatible(e) || !X.getType().getElement().isCompatible(e)) {
        return false;
    }
    if (X.getType().getY() > 1 || Ap.getType().getY() > 1) {
        return false;
    }
    // floor(sqrt(2 * len)) recovers N when len is a triangular number N*(N+1)/2.
    final int packedLen = Ap.getType().getX();
    final int n = (int) Math.sqrt((double) packedLen * 2);
    if (packedLen != (n * (n + 1)) / 2) {
        return false;
    }
    return incX > 0 && X.getType().getX() == 1 + (n - 1) * incX;
}
/**
 * Checks S/D/C/ZTPMV argument validation for one (Uplo, TransA, Diag, incX) combination.
 *
 * Pairs (Ap, X) from {@code mMatrix} accepted by validateTPMV run the variant matching the
 * matrix element type and must not throw RSRuntimeException. Rejected pairs must make all
 * four variants throw.
 */
private void xTPMV_API_test(int Uplo, int TransA, int Diag, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            if (!validateTPMV(elem, Uplo, TransA, Diag, a, x, incX)) {
                try {
                    mBLAS.STPMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for STPMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.DTPMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for DTPMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.CTPMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for CTPMV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.ZTPMV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for ZTPMV");
                } catch (RSRuntimeException expected) {
                }
                continue;
            }
            try {
                if (elem.isCompatible(Element.F32(mRS))) {
                    mBLAS.STPMV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64(mRS))) {
                    mBLAS.DTPMV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F32_2(mRS))) {
                    mBLAS.CTPMV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64_2(mRS))) {
                    mBLAS.ZTPMV(Uplo, TransA, Diag, a, x, incX);
                }
            } catch (RSRuntimeException unexpected) {
                fail("should NOT throw RSRuntimeException");
            }
        }
    }
}
/** Runs xTPMV_API_test on {@code mMatrix} for every Uplo, TransA, Diag and incX value. */
public void L2_xTPMV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            for (int d = 0; d < mDiag.length; d++) {
                for (int i = 0; i < mInc.length; i++) {
                    xTPMV_API_test(mUplo[u], mTranspose[t], mDiag[d], mInc[i], mMatrix);
                }
            }
        }
    }
}
/** TPMV API checks over the F32 fixtures. */
public void test_L2_STPMV_API() {
    L2_xTPMV_API(this.mMatrixS);
}
/** TPMV API checks over the F64 fixtures. */
public void test_L2_DTPMV_API() {
    L2_xTPMV_API(this.mMatrixD);
}
/** TPMV API checks over the F32_2 fixtures. */
public void test_L2_CTPMV_API() {
    L2_xTPMV_API(this.mMatrixC);
}
/** TPMV API checks over the F64_2 fixtures. */
public void test_L2_ZTPMV_API() {
    L2_xTPMV_API(this.mMatrixZ);
}
/**
 * Checks S/D/C/ZTRSV argument validation for one (Uplo, TransA, Diag, incX) combination.
 *
 * Uses the same acceptance rule as TRMV (validateTRMV). Accepted (A, X) pairs run the
 * variant matching A's element type and must not throw RSRuntimeException. Rejected pairs
 * must make all four variants throw.
 */
private void xTRSV_API_test(int Uplo, int TransA, int Diag, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            if (!validateTRMV(elem, Uplo, TransA, Diag, a, x, incX)) {
                try {
                    mBLAS.STRSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for STRSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.DTRSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for DTRSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.CTRSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for CTRSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.ZTRSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for ZTRSV");
                } catch (RSRuntimeException expected) {
                }
                continue;
            }
            try {
                if (elem.isCompatible(Element.F32(mRS))) {
                    mBLAS.STRSV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64(mRS))) {
                    mBLAS.DTRSV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F32_2(mRS))) {
                    mBLAS.CTRSV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64_2(mRS))) {
                    mBLAS.ZTRSV(Uplo, TransA, Diag, a, x, incX);
                }
            } catch (RSRuntimeException unexpected) {
                fail("should NOT throw RSRuntimeException");
            }
        }
    }
}
/** Runs xTRSV_API_test on {@code mMatrix} for every Uplo, TransA, Diag and incX value. */
public void L2_xTRSV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            for (int d = 0; d < mDiag.length; d++) {
                for (int i = 0; i < mInc.length; i++) {
                    xTRSV_API_test(mUplo[u], mTranspose[t], mDiag[d], mInc[i], mMatrix);
                }
            }
        }
    }
}
/** TRSV API checks over the F32 fixtures. */
public void test_L2_STRSV_API() {
    L2_xTRSV_API(this.mMatrixS);
}
/** TRSV API checks over the F64 fixtures. */
public void test_L2_DTRSV_API() {
    L2_xTRSV_API(this.mMatrixD);
}
/** TRSV API checks over the F32_2 fixtures. */
public void test_L2_CTRSV_API() {
    L2_xTRSV_API(this.mMatrixC);
}
/** TRSV API checks over the F64_2 fixtures. */
public void test_L2_ZTRSV_API() {
    L2_xTRSV_API(this.mMatrixZ);
}
/**
 * Checks S/D/C/ZTBSV argument validation for one (Uplo, TransA, Diag, K, incX) combination.
 *
 * A pair (A, X) from {@code mMatrix} is valid when validateTRMV accepts it and K is
 * non-negative. Valid pairs run the variant matching A's element type and must not throw
 * RSRuntimeException. Invalid pairs must make all four variants throw.
 */
private void xTBSV_API_test(int Uplo, int TransA, int Diag, int K, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            boolean accepted = validateTRMV(elem, Uplo, TransA, Diag, a, x, incX) && K >= 0;
            if (!accepted) {
                try {
                    mBLAS.STBSV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for STBSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.DTBSV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for DTBSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.CTBSV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for CTBSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.ZTBSV(Uplo, TransA, Diag, K, a, x, incX);
                    fail("should throw RSRuntimeException for ZTBSV");
                } catch (RSRuntimeException expected) {
                }
                continue;
            }
            try {
                if (elem.isCompatible(Element.F32(mRS))) {
                    mBLAS.STBSV(Uplo, TransA, Diag, K, a, x, incX);
                } else if (elem.isCompatible(Element.F64(mRS))) {
                    mBLAS.DTBSV(Uplo, TransA, Diag, K, a, x, incX);
                } else if (elem.isCompatible(Element.F32_2(mRS))) {
                    mBLAS.CTBSV(Uplo, TransA, Diag, K, a, x, incX);
                } else if (elem.isCompatible(Element.F64_2(mRS))) {
                    mBLAS.ZTBSV(Uplo, TransA, Diag, K, a, x, incX);
                }
            } catch (RSRuntimeException unexpected) {
                fail("should NOT throw RSRuntimeException");
            }
        }
    }
}
/** Runs xTBSV_API_test on {@code mMatrix} for every Uplo, TransA, Diag, band width K and incX value. */
public void L2_xTBSV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            for (int d = 0; d < mDiag.length; d++) {
                for (int k = 0; k < mK.length; k++) {
                    for (int i = 0; i < mInc.length; i++) {
                        xTBSV_API_test(mUplo[u], mTranspose[t], mDiag[d], mK[k], mInc[i], mMatrix);
                    }
                }
            }
        }
    }
}
/** TBSV API checks over the F32 fixtures. */
public void test_L2_STBSV_API() {
    L2_xTBSV_API(this.mMatrixS);
}
/** TBSV API checks over the F64 fixtures. */
public void test_L2_DTBSV_API() {
    L2_xTBSV_API(this.mMatrixD);
}
/** TBSV API checks over the F32_2 fixtures. */
public void test_L2_CTBSV_API() {
    L2_xTBSV_API(this.mMatrixC);
}
/** TBSV API checks over the F64_2 fixtures. */
public void test_L2_ZTBSV_API() {
    L2_xTBSV_API(this.mMatrixZ);
}
/**
 * Checks S/D/C/ZTPSV argument validation for one (Uplo, TransA, Diag, incX) combination.
 *
 * Uses the same acceptance rule as TPMV (validateTPMV). Accepted (Ap, X) pairs run the
 * variant matching the matrix element type and must not throw RSRuntimeException.
 * Rejected pairs must make all four variants throw.
 */
private void xTPSV_API_test(int Uplo, int TransA, int Diag, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            if (!validateTPMV(elem, Uplo, TransA, Diag, a, x, incX)) {
                try {
                    mBLAS.STPSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for STPSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.DTPSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for DTPSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.CTPSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for CTPSV");
                } catch (RSRuntimeException expected) {
                }
                try {
                    mBLAS.ZTPSV(Uplo, TransA, Diag, a, x, incX);
                    fail("should throw RSRuntimeException for ZTPSV");
                } catch (RSRuntimeException expected) {
                }
                continue;
            }
            try {
                if (elem.isCompatible(Element.F32(mRS))) {
                    mBLAS.STPSV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64(mRS))) {
                    mBLAS.DTPSV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F32_2(mRS))) {
                    mBLAS.CTPSV(Uplo, TransA, Diag, a, x, incX);
                } else if (elem.isCompatible(Element.F64_2(mRS))) {
                    mBLAS.ZTPSV(Uplo, TransA, Diag, a, x, incX);
                }
            } catch (RSRuntimeException unexpected) {
                fail("should NOT throw RSRuntimeException");
            }
        }
    }
}
/** Runs xTPSV_API_test on {@code mMatrix} for every Uplo, TransA, Diag and incX value. */
public void L2_xTPSV_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            for (int d = 0; d < mDiag.length; d++) {
                for (int i = 0; i < mInc.length; i++) {
                    xTPSV_API_test(mUplo[u], mTranspose[t], mDiag[d], mInc[i], mMatrix);
                }
            }
        }
    }
}
/** TPSV API checks over the F32 fixtures. */
public void test_L2_STPSV_API() {
    L2_xTPSV_API(this.mMatrixS);
}
/** TPSV API checks over the F64 fixtures. */
public void test_L2_DTPSV_API() {
    L2_xTPSV_API(this.mMatrixD);
}
/** TPSV API checks over the F32_2 fixtures. */
public void test_L2_CTPSV_API() {
    L2_xTPSV_API(this.mMatrixC);
}
/** TPSV API checks over the F64_2 fixtures. */
public void test_L2_ZTPSV_API() {
    L2_xTPSV_API(this.mMatrixZ);
}
/**
 * Encodes which rank-1 update (SGER/DGER) argument combinations the test expects to be accepted.
 *
 * Returns true iff: A, X and Y all have elements compatible with {@code e}; X and Y are
 * one-dimensional; A has at least one row and one column; both increments are positive;
 * and, with M = A's Y extent and N = A's X extent, |X| = 1 + (M-1)*incX and
 * |Y| = 1 + (N-1)*incY.
 */
private boolean validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
    final boolean elementsMatch = A.getType().getElement().isCompatible(e)
            && X.getType().getElement().isCompatible(e)
            && Y.getType().getElement().isCompatible(e);
    if (!elementsMatch) {
        return false;
    }
    if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
        return false;
    }
    final int m = A.getType().getY();
    final int n = A.getType().getX();
    if (m < 1 || n < 1 || incX <= 0 || incY <= 0) {
        return false;
    }
    return X.getType().getX() == 1 + (m - 1) * incX
            && Y.getType().getX() == 1 + (n - 1) * incY;
}
/**
 * Checks S/DGER argument validation for one (incX, incY) combination.
 *
 * Triples (A, X, Y) from {@code mMatrix} accepted by validateGER run SGER (F32) or DGER
 * (F64) according to A's element type and must not throw RSRuntimeException. Rejected
 * triples must make both variants throw.
 */
private void xGER_API_test(int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (!validateGER(elem, x, incX, y, incY, a)) {
                    try {
                        mBLAS.SGER(alphaS, x, incX, y, incY, a);
                        fail("should throw RSRuntimeException for SGER");
                    } catch (RSRuntimeException expected) {
                    }
                    try {
                        mBLAS.DGER(alphaD, x, incX, y, incY, a);
                        fail("should throw RSRuntimeException for DGER");
                    } catch (RSRuntimeException expected) {
                    }
                    continue;
                }
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SGER(alphaS, x, incX, y, incY, a);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DGER(alphaD, x, incX, y, incY, a);
                    }
                } catch (RSRuntimeException unexpected) {
                    fail("should NOT throw RSRuntimeException");
                }
            }
        }
    }
}
/**
 * Runs the real-valued rank-1 update checks (xGER_API_test, i.e. SGER/DGER) on
 * {@code mMatrix} for every (incX, incY) pair drawn from mInc.
 */
private void L2_xGER_API(ArrayList<Allocation> mMatrix) {
    for (int incX : mInc) {
        for (int incY : mInc) {
            // Must dispatch to the SGER/DGER checks. Routing to xGERU_API_test would only
            // exercise CGERU/ZGERU, leaving SGER/DGER untested and xGER_API_test unused.
            xGER_API_test(incX, incY, mMatrix);
        }
    }
}
/** GER API checks over the F32 fixtures. */
public void test_L2_SGER_API() {
    L2_xGER_API(mMatrixS);
}
/** GER API checks over the F64 fixtures. */
public void test_L2_DGER_API() {
    L2_xGER_API(mMatrixD);
}
/**
 * Predicts whether the complex rank-1 update routines (GERU, and GERC in this file)
 * should accept the given operands.
 *
 * <p>A, X and Y must all be compatible with {@code e}. X and Y must be one-dimensional.
 * Both increments must be positive. X must hold 1 + (M - 1) * incX entries and Y must hold
 * 1 + (N - 1) * incY entries, where M and N are A's Y and X dimensions. No lower bound on
 * M or N is checked here.
 */
private boolean validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
    for (Allocation operand : new Allocation[] {A, X, Y}) {
        if (!operand.getType().getElement().isCompatible(e)) {
            return false;
        }
    }
    if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
        return false;
    }
    if (incX <= 0 || incY <= 0) {
        return false;
    }
    int rows = A.getType().getY();
    int cols = A.getType().getX();
    return X.getType().getX() == 1 + (rows - 1) * incX
            && Y.getType().getX() == 1 + (cols - 1) * incY;
}
/**
 * Cross-checks CGERU/ZGERU argument validation against {@link #validateGERU}.
 *
 * <p>Accepted (A, X, Y) triples must run without an RSRuntimeException on the routine
 * matching A's element type. Rejected triples must make both CGERU and ZGERU throw.
 */
private void xGERU_API_test(int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (validateGERU(elem, x, incX, y, incY, a)) {
                    try {
                        if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CGERU(alphaC, x, incX, y, incY, a);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZGERU(alphaZ, x, incX, y, incY, a);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.CGERU(alphaC, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for CGERU");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZGERU(alphaZ, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for ZGERU");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (incX, incY) pair from {@code mInc} through {@link #xGERU_API_test}. */
private void L2_xGERU_API(ArrayList<Allocation> mMatrix) {
    for (int ix = 0; ix < mInc.length; ix++) {
        for (int iy = 0; iy < mInc.length; iy++) {
            xGERU_API_test(mInc[ix], mInc[iy], mMatrix);
        }
    }
}
/** Runs the GERU API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L2_CGERU_API() {
    this.L2_xGERU_API(this.mMatrixC);
}
/** Runs the GERU API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L2_ZGERU_API() {
    this.L2_xGERU_API(this.mMatrixZ);
}
/**
 * Cross-checks CGERC/ZGERC argument validation. GERC shares GERU's shape rules here,
 * so acceptance is predicted by {@link #validateGERU}.
 */
private void xGERC_API_test(int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (validateGERU(elem, x, incX, y, incY, a)) {
                    try {
                        if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CGERC(alphaC, x, incX, y, incY, a);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZGERC(alphaZ, x, incX, y, incY, a);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.CGERC(alphaC, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for CGERC");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZGERC(alphaZ, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for ZGERC");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (incX, incY) pair from {@code mInc} through {@link #xGERC_API_test}. */
private void L2_xGERC_API(ArrayList<Allocation> mMatrix) {
    for (int ix = 0; ix < mInc.length; ix++) {
        for (int iy = 0; iy < mInc.length; iy++) {
            xGERC_API_test(mInc[ix], mInc[iy], mMatrix);
        }
    }
}
/** Runs the GERC API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L2_CGERC_API() {
    this.L2_xGERC_API(this.mMatrixC);
}
/** Runs the GERC API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L2_ZGERC_API() {
    this.L2_xGERC_API(this.mMatrixZ);
}
/**
 * Cross-checks CHER/ZHER argument validation. HER shares SYR's shape rules here,
 * so acceptance is predicted by {@link #validateSYR}. HER takes a real alpha,
 * hence alphaS/alphaD.
 */
private void xHER_API_test(int Uplo, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            if (validateSYR(elem, Uplo, x, incX, a)) {
                try {
                    if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CHER(Uplo, alphaS, x, incX, a);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZHER(Uplo, alphaD, x, incX, a);
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException");
                }
                continue;
            }
            try {
                mBLAS.CHER(Uplo, alphaS, x, incX, a);
                fail("should throw RSRuntimeException for CHER");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.ZHER(Uplo, alphaD, x, incX, a);
                fail("should throw RSRuntimeException for ZHER");
            } catch (RSRuntimeException expected) {
                // expected
            }
        }
    }
}
/** Feeds every {@code mUplo} value and every {@code mInc} increment through {@link #xHER_API_test}. */
public void L2_xHER_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            xHER_API_test(mUplo[u], mInc[i], mMatrix);
        }
    }
}
/** Runs the HER API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L2_CHER_API() {
    this.L2_xHER_API(this.mMatrixC);
}
/** Runs the HER API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L2_ZHER_API() {
    this.L2_xHER_API(this.mMatrixZ);
}
/**
 * Cross-checks CHPR/ZHPR argument validation. HPR shares SPR's packed-shape rules here,
 * so acceptance is predicted by {@link #validateSPR}.
 */
private void xHPR_API_test(int Uplo, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation ap : mMatrix) {
        Element elem = ap.getType().getElement();
        for (Allocation x : mMatrix) {
            if (validateSPR(elem, Uplo, x, incX, ap)) {
                try {
                    if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CHPR(Uplo, alphaS, x, incX, ap);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZHPR(Uplo, alphaD, x, incX, ap);
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException");
                }
                continue;
            }
            try {
                mBLAS.CHPR(Uplo, alphaS, x, incX, ap);
                fail("should throw RSRuntimeException for CHPR");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.ZHPR(Uplo, alphaD, x, incX, ap);
                fail("should throw RSRuntimeException for ZHPR");
            } catch (RSRuntimeException expected) {
                // expected
            }
        }
    }
}
/** Feeds every {@code mUplo} value and every {@code mInc} increment through {@link #xHPR_API_test}. */
public void L2_xHPR_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            xHPR_API_test(mUplo[u], mInc[i], mMatrix);
        }
    }
}
/** Runs the HPR API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L2_CHPR_API() {
    this.L2_xHPR_API(this.mMatrixC);
}
/** Runs the HPR API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L2_ZHPR_API() {
    this.L2_xHPR_API(this.mMatrixZ);
}
/**
 * Cross-checks CHER2/ZHER2 argument validation. HER2 shares SYR2's shape rules here,
 * so acceptance is predicted by {@link #validateSYR2}.
 */
private void xHER2_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (validateSYR2(elem, Uplo, x, incX, y, incY, a)) {
                    try {
                        if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CHER2(Uplo, alphaC, x, incX, y, incY, a);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZHER2(Uplo, alphaZ, x, incX, y, incY, a);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.CHER2(Uplo, alphaC, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for CHER2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZHER2(Uplo, alphaZ, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for ZHER2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Uplo, incX, incY) combination from {@code mUplo}/{@code mInc} through {@link #xHER2_API_test}. */
public void L2_xHER2_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int ix = 0; ix < mInc.length; ix++) {
            for (int iy = 0; iy < mInc.length; iy++) {
                xHER2_API_test(mUplo[u], mInc[ix], mInc[iy], mMatrix);
            }
        }
    }
}
/** Runs the HER2 API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L2_CHER2_API() {
    this.L2_xHER2_API(this.mMatrixC);
}
/** Runs the HER2 API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L2_ZHER2_API() {
    this.L2_xHER2_API(this.mMatrixZ);
}
/**
 * Cross-checks CHPR2/ZHPR2 argument validation. HPR2 shares SPR2's packed-shape rules
 * here, so acceptance is predicted by {@link #validateSPR2}.
 */
private void xHPR2_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation ap : mMatrix) {
        Element elem = ap.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (validateSPR2(elem, Uplo, x, incX, y, incY, ap)) {
                    try {
                        if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CHPR2(Uplo, alphaC, x, incX, y, incY, ap);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZHPR2(Uplo, alphaZ, x, incX, y, incY, ap);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.CHPR2(Uplo, alphaC, x, incX, y, incY, ap);
                    fail("should throw RSRuntimeException for CHPR2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZHPR2(Uplo, alphaZ, x, incX, y, incY, ap);
                    fail("should throw RSRuntimeException for ZHPR2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Uplo, incX, incY) combination from {@code mUplo}/{@code mInc} through {@link #xHPR2_API_test}. */
public void L2_xHPR2_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int ix = 0; ix < mInc.length; ix++) {
            for (int iy = 0; iy < mInc.length; iy++) {
                xHPR2_API_test(mUplo[u], mInc[ix], mInc[iy], mMatrix);
            }
        }
    }
}
/** Runs the HPR2 API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L2_CHPR2_API() {
    this.L2_xHPR2_API(this.mMatrixC);
}
/** Runs the HPR2 API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L2_ZHPR2_API() {
    this.L2_xHPR2_API(this.mMatrixZ);
}
/**
 * Predicts whether a SYR-style call (also used here for HER) should accept its operands.
 *
 * <p>Requires a valid {@code Uplo}, A and X compatible with {@code e}, a one-dimensional X,
 * a square N x N matrix A, a positive incX, and an X of exactly 1 + (N - 1) * incX entries.
 */
private boolean validateSYR(Element e, int Uplo, Allocation X, int incX, Allocation A) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    boolean sameElement = A.getType().getElement().isCompatible(e)
            && X.getType().getElement().isCompatible(e);
    if (!sameElement || X.getType().getY() > 1) {
        return false;
    }
    int n = A.getType().getX();
    boolean square = n == A.getType().getY();
    if (!square || incX <= 0) {
        return false;
    }
    return X.getType().getX() == 1 + (n - 1) * incX;
}
/**
 * Cross-checks SSYR/DSYR argument validation against {@link #validateSYR}. Accepted
 * (A, X) pairs must run cleanly; rejected pairs must make both SSYR and DSYR throw.
 */
private void xSYR_API_test(int Uplo, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            if (validateSYR(elem, Uplo, x, incX, a)) {
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SSYR(Uplo, alphaS, x, incX, a);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DSYR(Uplo, alphaD, x, incX, a);
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException");
                }
                continue;
            }
            try {
                mBLAS.SSYR(Uplo, alphaS, x, incX, a);
                fail("should throw RSRuntimeException for SSYR");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.DSYR(Uplo, alphaD, x, incX, a);
                fail("should throw RSRuntimeException for DSYR");
            } catch (RSRuntimeException expected) {
                // expected
            }
        }
    }
}
/** Feeds every {@code mUplo} value and every {@code mInc} increment through {@link #xSYR_API_test}. */
public void L2_xSYR_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            xSYR_API_test(mUplo[u], mInc[i], mMatrix);
        }
    }
}
/** Runs the SYR API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L2_SSYR_API() {
    this.L2_xSYR_API(this.mMatrixS);
}
/** Runs the SYR API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L2_DSYR_API() {
    this.L2_xSYR_API(this.mMatrixD);
}
/**
 * Predicts whether an SPR-style call (also used here for HPR) should accept its operands.
 *
 * <p>{@code Ap} must be one-dimensional, and its length must be a triangular number
 * N * (N + 1) / 2. N is then used to check that X holds 1 + (N - 1) * incX entries.
 */
private boolean validateSPR(Element e, int Uplo, Allocation X, int incX, Allocation Ap) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    if (!(Ap.getType().getElement().isCompatible(e) && X.getType().getElement().isCompatible(e))) {
        return false;
    }
    if (X.getType().getY() > 1 || Ap.getType().getY() > 1) {
        return false;
    }
    // Recover N from the packed length; a packed triangle of order N holds N*(N+1)/2 entries.
    int packedLen = Ap.getType().getX();
    int n = (int) Math.sqrt((double) packedLen * 2);
    boolean triangular = packedLen == (n * (n + 1)) / 2;
    if (!triangular || incX <= 0) {
        return false;
    }
    return X.getType().getX() == 1 + (n - 1) * incX;
}
/**
 * Cross-checks SSPR/DSPR argument validation against {@link #validateSPR}. Accepted
 * (Ap, X) pairs must run cleanly; rejected pairs must make both SSPR and DSPR throw.
 */
private void xSPR_API_test(int Uplo, int incX, ArrayList<Allocation> mMatrix) {
    for (Allocation ap : mMatrix) {
        Element elem = ap.getType().getElement();
        for (Allocation x : mMatrix) {
            if (validateSPR(elem, Uplo, x, incX, ap)) {
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SSPR(Uplo, alphaS, x, incX, ap);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DSPR(Uplo, alphaD, x, incX, ap);
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException");
                }
                continue;
            }
            try {
                mBLAS.SSPR(Uplo, alphaS, x, incX, ap);
                fail("should throw RSRuntimeException for SSPR");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.DSPR(Uplo, alphaD, x, incX, ap);
                fail("should throw RSRuntimeException for DSPR");
            } catch (RSRuntimeException expected) {
                // expected
            }
        }
    }
}
/** Feeds every {@code mUplo} value and every {@code mInc} increment through {@link #xSPR_API_test}. */
public void L2_xSPR_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int i = 0; i < mInc.length; i++) {
            xSPR_API_test(mUplo[u], mInc[i], mMatrix);
        }
    }
}
/** Runs the SPR API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L2_SSPR_API() {
    this.L2_xSPR_API(this.mMatrixS);
}
/** Runs the SPR API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L2_DSPR_API() {
    this.L2_xSPR_API(this.mMatrixD);
}
/**
 * Predicts whether a SYR2-style call (also used here for HER2) should accept its operands.
 *
 * <p>Requires a valid {@code Uplo}, element-compatible operands, one-dimensional X and Y,
 * a square N x N matrix A, positive increments, and vector lengths of exactly
 * 1 + (N - 1) * incX and 1 + (N - 1) * incY.
 */
private boolean validateSYR2(Element e, int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    for (Allocation operand : new Allocation[] {A, X, Y}) {
        if (!operand.getType().getElement().isCompatible(e)) {
            return false;
        }
    }
    if (X.getType().getY() > 1 || Y.getType().getY() > 1) {
        return false;
    }
    int n = A.getType().getX();
    if (n != A.getType().getY() || incX <= 0 || incY <= 0) {
        return false;
    }
    int wantX = 1 + (n - 1) * incX;
    int wantY = 1 + (n - 1) * incY;
    return X.getType().getX() == wantX && Y.getType().getX() == wantY;
}
/**
 * Cross-checks SSYR2/DSYR2 argument validation against {@link #validateSYR2}. Accepted
 * (A, X, Y) triples must run cleanly; rejected triples must make both routines throw.
 */
private void xSYR2_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (validateSYR2(elem, Uplo, x, incX, y, incY, a)) {
                    try {
                        if (elem.isCompatible(Element.F32(mRS))) {
                            mBLAS.SSYR2(Uplo, alphaS, x, incX, y, incY, a);
                        } else if (elem.isCompatible(Element.F64(mRS))) {
                            mBLAS.DSYR2(Uplo, alphaD, x, incX, y, incY, a);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.SSYR2(Uplo, alphaS, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for SSYR2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.DSYR2(Uplo, alphaD, x, incX, y, incY, a);
                    fail("should throw RSRuntimeException for DSYR2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Uplo, incX, incY) combination from {@code mUplo}/{@code mInc} through {@link #xSYR2_API_test}. */
public void L2_xSYR2_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int ix = 0; ix < mInc.length; ix++) {
            for (int iy = 0; iy < mInc.length; iy++) {
                xSYR2_API_test(mUplo[u], mInc[ix], mInc[iy], mMatrix);
            }
        }
    }
}
/** Runs the SYR2 API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L2_SSYR2_API() {
    this.L2_xSYR2_API(this.mMatrixS);
}
/** Runs the SYR2 API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L2_DSYR2_API() {
    this.L2_xSYR2_API(this.mMatrixD);
}
/**
 * Predicts whether an SPR2-style call (also used here for HPR2) should accept its operands.
 *
 * <p>{@code Ap} must be one-dimensional with a triangular length N * (N + 1) / 2. X and Y
 * must be one-dimensional, with 1 + (N - 1) * inc entries for their positive increments.
 */
private boolean validateSPR2(Element e, int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    for (Allocation operand : new Allocation[] {Ap, X, Y}) {
        if (!operand.getType().getElement().isCompatible(e)) {
            return false;
        }
    }
    if (X.getType().getY() > 1 || Y.getType().getY() > 1 || Ap.getType().getY() > 1) {
        return false;
    }
    // Recover N from the packed length; a packed triangle of order N holds N*(N+1)/2 entries.
    int packedLen = Ap.getType().getX();
    int n = (int) Math.sqrt((double) packedLen * 2);
    if (packedLen != (n * (n + 1)) / 2) {
        return false;
    }
    if (incX <= 0 || incY <= 0) {
        return false;
    }
    return X.getType().getX() == 1 + (n - 1) * incX
            && Y.getType().getX() == 1 + (n - 1) * incY;
}
/**
 * Cross-checks SSPR2/DSPR2 argument validation against {@link #validateSPR2}. Accepted
 * (Ap, X, Y) triples must run cleanly; rejected triples must make both routines throw.
 */
private void xSPR2_API_test(int Uplo, int incX, int incY, ArrayList<Allocation> mMatrix) {
    for (Allocation ap : mMatrix) {
        Element elem = ap.getType().getElement();
        for (Allocation x : mMatrix) {
            for (Allocation y : mMatrix) {
                if (validateSPR2(elem, Uplo, x, incX, y, incY, ap)) {
                    try {
                        if (elem.isCompatible(Element.F32(mRS))) {
                            mBLAS.SSPR2(Uplo, alphaS, x, incX, y, incY, ap);
                        } else if (elem.isCompatible(Element.F64(mRS))) {
                            mBLAS.DSPR2(Uplo, alphaD, x, incX, y, incY, ap);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.SSPR2(Uplo, alphaS, x, incX, y, incY, ap);
                    fail("should throw RSRuntimeException for SSPR2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.DSPR2(Uplo, alphaD, x, incX, y, incY, ap);
                    fail("should throw RSRuntimeException for DSPR2");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Uplo, incX, incY) combination from {@code mUplo}/{@code mInc} through {@link #xSPR2_API_test}. */
public void L2_xSPR2_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int ix = 0; ix < mInc.length; ix++) {
            for (int iy = 0; iy < mInc.length; iy++) {
                xSPR2_API_test(mUplo[u], mInc[ix], mInc[iy], mMatrix);
            }
        }
    }
}
/** Runs the SPR2 API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L2_SSPR2_API() {
    this.L2_xSPR2_API(this.mMatrixS);
}
/** Runs the SPR2 API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L2_DSPR2_API() {
    this.L2_xSPR2_API(this.mMatrixD);
}
/**
 * Shared level-3 shape predicate used by the GEMM, SYMM and SYRK checks in this file.
 *
 * <p>Returns false on any element mismatch or when C is null. It then derives the logical
 * (rows, cols) of each non-null operand and checks that they chain into C.
 *
 * @param e      element type every non-null operand must be compatible with
 * @param TransA transpose flag for A; only consulted when Side != RIGHT
 * @param TransB transpose flag for B; only consulted when Side != RIGHT
 * @param Side   ScriptIntrinsicBLAS.RIGHT swaps the roles of A and B; any other value
 *               (callers here pass 0 or LEFT) takes the transpose-aware path
 * @param A      first operand, may be null (SYRK passes only A and C)
 * @param B      second operand, may be null
 * @param C      result matrix; must be non-null
 * @return true when the operand shapes are consistent
 */
private boolean validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) {
// -1 marks "not derived" for any operand that is null.
int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
if ((A != null && !A.getType().getElement().isCompatible(e)) ||
(B != null && !B.getType().getElement().isCompatible(e)) ||
(C != null && !C.getType().getElement().isCompatible(e))) {
return false;
}
if (C == null) {
//since matrix C is used to store the result, it cannot be null.
return false;
}
cM = C.getType().getY();
cN = C.getType().getX();
if (Side == ScriptIntrinsicBLAS.RIGHT) {
// RIGHT side: A and B must be supplied together.
if ((A == null && B != null) || (A != null && B == null)) {
return false;
}
// Roles are swapped: "b" dims come from A and "a" dims come from B, so the
// product check below validates B * A. Transpose flags are ignored on this path.
// Both guards are equivalent here because A and B are both null or both non-null.
if (B != null) {
bM = A.getType().getY();
bN = A.getType().getX();
}
if (A != null) {
aM = B.getType().getY();
aN = B.getType().getX();
}
} else {
// Logical dims of op(A): TRANSPOSE and CONJ_TRANSPOSE both swap rows and columns.
if (A != null) {
if (TransA == ScriptIntrinsicBLAS.TRANSPOSE ||
TransA == ScriptIntrinsicBLAS.CONJ_TRANSPOSE ) {
aN = A.getType().getY();
aM = A.getType().getX();
} else {
aM = A.getType().getY();
aN = A.getType().getX();
}
}
// Logical dims of op(B), same rule as for A.
if (B != null) {
if (TransB == ScriptIntrinsicBLAS.TRANSPOSE ||
TransB == ScriptIntrinsicBLAS.CONJ_TRANSPOSE ) {
bN = B.getType().getY();
bM = B.getType().getX();
} else {
bM = B.getType().getY();
bN = B.getType().getX();
}
}
}
if (A != null && B != null && C != null) {
// (aM x aN) * (bM x bN) must be conformable and produce a cM x cN result.
if (aN != bM || aM != cM || bN != cN) {
return false;
}
} else if (A != null && C != null) {
// A and C only, for SYRK
if (cM != cN) {
return false;
}
if (aM != cM) {
return false;
}
} else if (A != null && B != null) {
// A and B only
// NOTE(review): unreachable, because C == null already returned false above.
if (aN != bM) {
return false;
}
}
// Reached when all applicable checks pass, including the case where only C is given.
return true;
}
/**
 * Predicts whether a GEMM call should succeed. Both transpose flags must be valid and
 * {@link #validateL3} must accept the shapes. All three checks are always evaluated.
 */
private boolean validateL3_xGEMM(Element e, int TransA, int TransB, Allocation A, Allocation B, Allocation C) {
    boolean transposesOk = validateTranspose(TransA) & validateTranspose(TransB);
    boolean shapesOk = validateL3(e, TransA, TransB, 0, A, B, C);
    return transposesOk && shapesOk;
}
/**
 * Cross-checks the four GEMM entry points against {@link #validateL3_xGEMM}.
 *
 * <p>Every (A, B, C) triple from {@code mMatrix} is tried with the given transpose flags.
 * Accepted triples must run without an RSRuntimeException on the routine matching A's
 * element type. Rejected triples must make SGEMM, DGEMM, CGEMM and ZGEMM all throw.
 */
private void xGEMM_API_test(int transA, int transB, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation b : mMatrix) {
            for (Allocation c : mMatrix) {
                if (validateL3_xGEMM(elem, transA, transB, a, b, c)) {
                    try {
                        if (elem.isCompatible(Element.F32(mRS))) {
                            mBLAS.SGEMM(transA, transB, alphaS, a, b, betaS, c);
                        } else if (elem.isCompatible(Element.F64(mRS))) {
                            mBLAS.DGEMM(transA, transB, alphaD, a, b, betaD, c);
                        } else if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CGEMM(transA, transB, alphaC, a, b, betaC, c);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZGEMM(transA, transB, alphaZ, a, b, betaZ, c);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.SGEMM(transA, transB, alphaS, a, b, betaS, c);
                    fail("should throw RSRuntimeException for SGEMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.DGEMM(transA, transB, alphaD, a, b, betaD, c);
                    fail("should throw RSRuntimeException for DGEMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.CGEMM(transA, transB, alphaC, a, b, betaC, c);
                    fail("should throw RSRuntimeException for CGEMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZGEMM(transA, transB, alphaZ, a, b, betaZ, c);
                    fail("should throw RSRuntimeException for ZGEMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (transA, transB) pair from {@code mTranspose} through {@link #xGEMM_API_test}. */
private void L3_xGEMM_API(ArrayList<Allocation> mMatrix) {
    for (int ta = 0; ta < mTranspose.length; ta++) {
        for (int tb = 0; tb < mTranspose.length; tb++) {
            xGEMM_API_test(mTranspose[ta], mTranspose[tb], mMatrix);
        }
    }
}
/** Runs the GEMM API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L3_SGEMM_API() {
    this.L3_xGEMM_API(this.mMatrixS);
}
/** Runs the GEMM API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L3_DGEMM_API() {
    this.L3_xGEMM_API(this.mMatrixD);
}
/** Runs the GEMM API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L3_CGEMM_API() {
    this.L3_xGEMM_API(this.mMatrixC);
}
/** Runs the GEMM API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L3_ZGEMM_API() {
    this.L3_xGEMM_API(this.mMatrixZ);
}
/**
 * Predicts whether a SYMM call should succeed. Side and Uplo must be valid,
 * {@link #validateL3} must accept the shapes, and A must be square. Every check is
 * evaluated, in the same order as before.
 */
private boolean validateL3_xSYMM(Element e, int Side, int Uplo, Allocation A, Allocation B, Allocation C) {
    boolean flagsOk = validateSide(Side) & validateUplo(Uplo);
    boolean shapesOk = validateL3(e, 0, 0, Side, A, B, C);
    boolean aSquare = A.getType().getX() == A.getType().getY();
    return flagsOk & shapesOk & aSquare;
}
/**
 * Cross-checks the four SYMM entry points against {@link #validateL3_xSYMM}. Accepted
 * (A, B, C) triples must run cleanly on the routine matching A's element type. Rejected
 * triples must make every precision throw.
 */
private void xSYMM_API_test(int Side, int Uplo, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation b : mMatrix) {
            for (Allocation c : mMatrix) {
                if (validateL3_xSYMM(elem, Side, Uplo, a, b, c)) {
                    try {
                        if (elem.isCompatible(Element.F32(mRS))) {
                            mBLAS.SSYMM(Side, Uplo, alphaS, a, b, betaS, c);
                        } else if (elem.isCompatible(Element.F64(mRS))) {
                            mBLAS.DSYMM(Side, Uplo, alphaD, a, b, betaD, c);
                        } else if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CSYMM(Side, Uplo, alphaC, a, b, betaC, c);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZSYMM(Side, Uplo, alphaZ, a, b, betaZ, c);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.SSYMM(Side, Uplo, alphaS, a, b, betaS, c);
                    fail("should throw RSRuntimeException for SSYMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.DSYMM(Side, Uplo, alphaD, a, b, betaD, c);
                    fail("should throw RSRuntimeException for DSYMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.CSYMM(Side, Uplo, alphaC, a, b, betaC, c);
                    fail("should throw RSRuntimeException for CSYMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZSYMM(Side, Uplo, alphaZ, a, b, betaZ, c);
                    fail("should throw RSRuntimeException for ZSYMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Side, Uplo) pair from {@code mSide}/{@code mUplo} through {@link #xSYMM_API_test}. */
private void L3_xSYMM_API(ArrayList<Allocation> mMatrix) {
    for (int s = 0; s < mSide.length; s++) {
        for (int u = 0; u < mUplo.length; u++) {
            xSYMM_API_test(mSide[s], mUplo[u], mMatrix);
        }
    }
}
/** Runs the SYMM API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L3_SSYMM_API() {
    this.L3_xSYMM_API(this.mMatrixS);
}
/** Runs the SYMM API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L3_DSYMM_API() {
    this.L3_xSYMM_API(this.mMatrixD);
}
/** Runs the SYMM API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L3_CSYMM_API() {
    this.L3_xSYMM_API(this.mMatrixC);
}
/** Runs the SYMM API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L3_ZSYMM_API() {
    this.L3_xSYMM_API(this.mMatrixZ);
}
/**
 * Predicts whether a HEMM call should succeed.
 *
 * <p>Side and Uplo must be valid and all operands must be compatible with {@code e}.
 * A must be square with order n. On the LEFT side B must have n rows; on the RIGHT side
 * B must have n columns. C must have exactly B's dimensions.
 */
private boolean validateHEMM(Element e, int Side, int Uplo, Allocation A, Allocation B, Allocation C) {
    if (!validateSide(Side) || !validateUplo(Uplo)) {
        return false;
    }
    boolean elementsMatch = A.getType().getElement().isCompatible(e)
            && B.getType().getElement().isCompatible(e)
            && C.getType().getElement().isCompatible(e);
    if (!elementsMatch) {
        return false;
    }
    // A must be square; can potentially be relaxed similar to TRSM
    int order = A.getType().getX();
    if (order != A.getType().getY()) {
        return false;
    }
    int bRows = B.getType().getY();
    int bCols = B.getType().getX();
    boolean leftMismatch = Side == ScriptIntrinsicBLAS.LEFT && order != bRows;
    boolean rightMismatch = Side == ScriptIntrinsicBLAS.RIGHT && order != bCols;
    if (leftMismatch || rightMismatch) {
        return false;
    }
    return bCols == C.getType().getX() && bRows == C.getType().getY();
}
/**
 * Cross-checks CHEMM/ZHEMM argument validation against {@link #validateHEMM}. Accepted
 * (A, B, C) triples must run cleanly; rejected triples must make both routines throw.
 */
private void xHEMM_API_test(int Side, int Uplo, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation b : mMatrix) {
            for (Allocation c : mMatrix) {
                if (validateHEMM(elem, Side, Uplo, a, b, c)) {
                    try {
                        if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CHEMM(Side, Uplo, alphaC, a, b, betaC, c);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZHEMM(Side, Uplo, alphaZ, a, b, betaZ, c);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.CHEMM(Side, Uplo, alphaC, a, b, betaC, c);
                    fail("should throw RSRuntimeException for CHEMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZHEMM(Side, Uplo, alphaZ, a, b, betaZ, c);
                    fail("should throw RSRuntimeException for ZHEMM");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Side, Uplo) pair from {@code mSide}/{@code mUplo} through {@link #xHEMM_API_test}. */
public void L3_xHEMM_API(ArrayList<Allocation> mMatrix) {
    for (int s = 0; s < mSide.length; s++) {
        for (int u = 0; u < mUplo.length; u++) {
            xHEMM_API_test(mSide[s], mUplo[u], mMatrix);
        }
    }
}
/** Runs the HEMM API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L3_CHEMM_API() {
    this.L3_xHEMM_API(this.mMatrixC);
}
/** Runs the HEMM API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L3_ZHEMM_API() {
    this.L3_xHEMM_API(this.mMatrixZ);
}
/**
 * Predicts whether a SYRK call should succeed. Trans and Uplo must be valid, and
 * {@link #validateL3} must accept A and C with no B. All checks are always evaluated.
 */
private boolean validateL3_xSYRK(Element e, int Uplo, int Trans, Allocation A, Allocation C) {
    boolean flagsOk = validateTranspose(Trans) & validateUplo(Uplo);
    boolean shapesOk = validateL3(e, Trans, 0, 0, A, null, C);
    return flagsOk & shapesOk;
}
/**
 * Cross-checks the four SYRK entry points against {@link #validateL3_xSYRK}. Accepted
 * (A, C) pairs must run cleanly on the routine matching A's element type. Rejected pairs
 * must make every precision throw.
 */
private void xSYRK_API_test(int Uplo, int Trans, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation c : mMatrix) {
            if (validateL3_xSYRK(elem, Uplo, Trans, a, c)) {
                try {
                    if (elem.isCompatible(Element.F32(mRS))) {
                        mBLAS.SSYRK(Uplo, Trans, alphaS, a, betaS, c);
                    } else if (elem.isCompatible(Element.F64(mRS))) {
                        mBLAS.DSYRK(Uplo, Trans, alphaD, a, betaD, c);
                    } else if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CSYRK(Uplo, Trans, alphaC, a, betaC, c);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZSYRK(Uplo, Trans, alphaZ, a, betaZ, c);
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException");
                }
                continue;
            }
            try {
                mBLAS.SSYRK(Uplo, Trans, alphaS, a, betaS, c);
                fail("should throw RSRuntimeException for SSYRK");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.DSYRK(Uplo, Trans, alphaD, a, betaD, c);
                fail("should throw RSRuntimeException for DSYRK");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.CSYRK(Uplo, Trans, alphaC, a, betaC, c);
                fail("should throw RSRuntimeException for CSYRK");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.ZSYRK(Uplo, Trans, alphaZ, a, betaZ, c);
                fail("should throw RSRuntimeException for ZSYRK");
            } catch (RSRuntimeException expected) {
                // expected
            }
        }
    }
}
/** Feeds every (Uplo, Trans) pair from {@code mUplo}/{@code mTranspose} through {@link #xSYRK_API_test}. */
public void L3_xSYRK_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            xSYRK_API_test(mUplo[u], mTranspose[t], mMatrix);
        }
    }
}
/** Runs the SYRK API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L3_SSYRK_API() {
    this.L3_xSYRK_API(this.mMatrixS);
}
/** Runs the SYRK API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L3_DSYRK_API() {
    this.L3_xSYRK_API(this.mMatrixD);
}
/** Runs the SYRK API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L3_CSYRK_API() {
    this.L3_xSYRK_API(this.mMatrixC);
}
/** Runs the SYRK API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L3_ZSYRK_API() {
    this.L3_xSYRK_API(this.mMatrixZ);
}
/**
 * Predicts whether a HERK call should succeed.
 *
 * <p>Uplo must be valid, A and C must be compatible with {@code e}, and Trans must pass
 * {@code validateConjTranspose}. C must be square (cdim x cdim). With NO_TRANSPOSE,
 * A's row count must equal cdim; otherwise A's column count must.
 */
private boolean validateHERK(Element e, int Uplo, int Trans, Allocation A, Allocation C) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    if (!(A.getType().getElement().isCompatible(e) && C.getType().getElement().isCompatible(e))) {
        return false;
    }
    if (!validateConjTranspose(Trans)) {
        return false;
    }
    int cdim = C.getType().getX();
    if (cdim != C.getType().getY()) {
        return false;
    }
    int aExtent = (Trans == ScriptIntrinsicBLAS.NO_TRANSPOSE)
            ? A.getType().getY()
            : A.getType().getX();
    return aExtent == cdim;
}
/**
 * Cross-checks CHERK/ZHERK argument validation against {@link #validateHERK}. HERK takes
 * real alpha/beta, hence the S/D scalars. Rejected (A, C) pairs must make both routines throw.
 */
private void xHERK_API_test(int Uplo, int Trans, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation c : mMatrix) {
            if (validateHERK(elem, Uplo, Trans, a, c)) {
                try {
                    if (elem.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CHERK(Uplo, Trans, alphaS, a, betaS, c);
                    } else if (elem.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZHERK(Uplo, Trans, alphaD, a, betaD, c);
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException");
                }
                continue;
            }
            try {
                mBLAS.CHERK(Uplo, Trans, alphaS, a, betaS, c);
                fail("should throw RSRuntimeException for CHERK");
            } catch (RSRuntimeException expected) {
                // expected
            }
            try {
                mBLAS.ZHERK(Uplo, Trans, alphaD, a, betaD, c);
                fail("should throw RSRuntimeException for ZHERK");
            } catch (RSRuntimeException expected) {
                // expected
            }
        }
    }
}
/** Feeds every (Uplo, Trans) pair from {@code mUplo}/{@code mTranspose} through {@link #xHERK_API_test}. */
public void L3_xHERK_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            xHERK_API_test(mUplo[u], mTranspose[t], mMatrix);
        }
    }
}
/** Runs the HERK API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L3_CHERK_API() {
    this.L3_xHERK_API(this.mMatrixC);
}
/** Runs the HERK API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L3_ZHERK_API() {
    this.L3_xHERK_API(this.mMatrixZ);
}
/**
 * Predicts whether a SYR2K call should succeed.
 *
 * <p>Trans and Uplo must be valid and all operands compatible with {@code e}. A is n x k
 * when not transposed and k x n when Trans == TRANSPOSE; any other Trans value uses A's
 * row count. C must be n x n, and B must have exactly A's dimensions.
 */
private boolean validateSYR2K(Element e, int Uplo, int Trans, Allocation A, Allocation B, Allocation C) {
    if (!validateTranspose(Trans) || !validateUplo(Uplo)) {
        return false;
    }
    for (Allocation operand : new Allocation[] {A, B, C}) {
        if (!operand.getType().getElement().isCompatible(e)) {
            return false;
        }
    }
    // TRANSPOSE: compare A's columns against C; otherwise compare A's rows.
    int n = (Trans == ScriptIntrinsicBLAS.TRANSPOSE) ? A.getType().getX() : A.getType().getY();
    boolean cIsNxN = C.getType().getX() == n && C.getType().getY() == n;
    if (!cIsNxN) {
        return false;
    }
    // A dims == B dims
    return A.getType().getX() == B.getType().getX() && A.getType().getY() == B.getType().getY();
}
/**
 * Cross-checks the four SYR2K entry points against {@link #validateSYR2K}. Accepted
 * (A, B, C) triples must run cleanly on the routine matching A's element type. Rejected
 * triples must make every precision throw.
 */
private void xSYR2K_API_test(int Uplo, int Trans, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation b : mMatrix) {
            for (Allocation c : mMatrix) {
                if (validateSYR2K(elem, Uplo, Trans, a, b, c)) {
                    try {
                        if (elem.isCompatible(Element.F32(mRS))) {
                            mBLAS.SSYR2K(Uplo, Trans, alphaS, a, b, betaS, c);
                        } else if (elem.isCompatible(Element.F64(mRS))) {
                            mBLAS.DSYR2K(Uplo, Trans, alphaD, a, b, betaD, c);
                        } else if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CSYR2K(Uplo, Trans, alphaC, a, b, betaC, c);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZSYR2K(Uplo, Trans, alphaZ, a, b, betaZ, c);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.SSYR2K(Uplo, Trans, alphaS, a, b, betaS, c);
                    fail("should throw RSRuntimeException for SSYR2K");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.DSYR2K(Uplo, Trans, alphaD, a, b, betaD, c);
                    fail("should throw RSRuntimeException for DSYR2K");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.CSYR2K(Uplo, Trans, alphaC, a, b, betaC, c);
                    fail("should throw RSRuntimeException for CSYR2K");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZSYR2K(Uplo, Trans, alphaZ, a, b, betaZ, c);
                    fail("should throw RSRuntimeException for ZSYR2K");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Uplo, Trans) pair from {@code mUplo}/{@code mTranspose} through {@link #xSYR2K_API_test}. */
public void L3_xSYR2K_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            xSYR2K_API_test(mUplo[u], mTranspose[t], mMatrix);
        }
    }
}
/** Runs the SYR2K API validation sweep with the {@code mMatrixS} allocation set. */
public void test_L3_SSYR2K_API() {
    this.L3_xSYR2K_API(this.mMatrixS);
}
/** Runs the SYR2K API validation sweep with the {@code mMatrixD} allocation set. */
public void test_L3_DSYR2K_API() {
    this.L3_xSYR2K_API(this.mMatrixD);
}
/** Runs the SYR2K API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L3_CSYR2K_API() {
    this.L3_xSYR2K_API(this.mMatrixC);
}
/** Runs the SYR2K API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L3_ZSYR2K_API() {
    this.L3_xSYR2K_API(this.mMatrixZ);
}
/**
 * Predicts whether a HER2K call should succeed.
 *
 * <p>Uplo must be valid, all operands compatible with {@code e}, and Trans accepted by
 * {@code validateConjTranspose}. C must be square (cdim x cdim). With NO_TRANSPOSE,
 * A's row count must equal cdim; otherwise A's column count must. B must match A's
 * dimensions exactly.
 */
private boolean validateHER2K(Element e, int Uplo, int Trans, Allocation A, Allocation B, Allocation C) {
    if (!validateUplo(Uplo)) {
        return false;
    }
    for (Allocation operand : new Allocation[] {A, B, C}) {
        if (!operand.getType().getElement().isCompatible(e)) {
            return false;
        }
    }
    if (!validateConjTranspose(Trans)) {
        return false;
    }
    int cdim = C.getType().getX();
    if (cdim != C.getType().getY()) {
        return false;
    }
    int aExtent = (Trans == ScriptIntrinsicBLAS.NO_TRANSPOSE)
            ? A.getType().getY()
            : A.getType().getX();
    if (aExtent != cdim) {
        return false;
    }
    return A.getType().getX() == B.getType().getX() && A.getType().getY() == B.getType().getY();
}
/**
 * Cross-checks CHER2K/ZHER2K argument validation against {@link #validateHER2K}. HER2K
 * takes a complex alpha but a real beta, hence the mixed scalars. Rejected triples must
 * make both routines throw.
 */
private void xHER2K_API_test(int Uplo, int Trans, ArrayList<Allocation> mMatrix) {
    for (Allocation a : mMatrix) {
        Element elem = a.getType().getElement();
        for (Allocation b : mMatrix) {
            for (Allocation c : mMatrix) {
                if (validateHER2K(elem, Uplo, Trans, a, b, c)) {
                    try {
                        if (elem.isCompatible(Element.F32_2(mRS))) {
                            mBLAS.CHER2K(Uplo, Trans, alphaC, a, b, betaS, c);
                        } else if (elem.isCompatible(Element.F64_2(mRS))) {
                            mBLAS.ZHER2K(Uplo, Trans, alphaZ, a, b, betaD, c);
                        }
                    } catch (RSRuntimeException e) {
                        fail("should NOT throw RSRuntimeException");
                    }
                    continue;
                }
                try {
                    mBLAS.CHER2K(Uplo, Trans, alphaC, a, b, betaS, c);
                    fail("should throw RSRuntimeException for CHER2K");
                } catch (RSRuntimeException expected) {
                    // expected
                }
                try {
                    mBLAS.ZHER2K(Uplo, Trans, alphaZ, a, b, betaD, c);
                    fail("should throw RSRuntimeException for ZHER2K");
                } catch (RSRuntimeException expected) {
                    // expected
                }
            }
        }
    }
}
/** Feeds every (Uplo, Trans) pair from {@code mUplo}/{@code mTranspose} through {@link #xHER2K_API_test}. */
public void L3_xHER2K_API(ArrayList<Allocation> mMatrix) {
    for (int u = 0; u < mUplo.length; u++) {
        for (int t = 0; t < mTranspose.length; t++) {
            xHER2K_API_test(mUplo[u], mTranspose[t], mMatrix);
        }
    }
}
/** Runs the HER2K API validation sweep with the {@code mMatrixC} allocation set. */
public void test_L3_CHER2K_API() {
    this.L3_xHER2K_API(this.mMatrixC);
}
/** Runs the HER2K API validation sweep with the {@code mMatrixZ} allocation set. */
public void test_L3_ZHER2K_API() {
    this.L3_xHER2K_API(this.mMatrixZ);
}
/**
 * Predicts whether a TRMM call should succeed.
 *
 * <p>Side, Uplo, TransA and Diag must all be valid (checked in that order, stopping at the
 * first failure). A and B must be compatible with {@code e}, and A must be square. On the
 * LEFT side A's column count must equal B's row count; otherwise B's column count must
 * equal A's row count.
 */
private boolean validateTRMM(Element e, int Side, int Uplo, int TransA, int Diag, Allocation A, Allocation B) {
    boolean flagsOk = validateSide(Side)
            && validateUplo(Uplo)
            && validateTranspose(TransA)
            && validateDiag(Diag);
    if (!flagsOk) {
        return false;
    }
    if (!(A.getType().getElement().isCompatible(e) && B.getType().getElement().isCompatible(e))) {
        return false;
    }
    int aRows = A.getType().getY();
    int aCols = A.getType().getX();
    if (aRows != aCols) {
        return false;
    }
    int bRows = B.getType().getY();
    int bCols = B.getType().getX();
    return (Side == ScriptIntrinsicBLAS.LEFT) ? aCols == bRows : bCols == aRows;
}
/**
 * Exercises STRMM/DTRMM/CTRMM/ZTRMM with every (A, B) pair drawn from {@code mMatrix}
 * for a single Side/Uplo/TransA/Diag combination.
 *
 * <p>If {@link #validateTRMM} accepts the pair and A's element is F32, F64, F32_2 or
 * F64_2, the matching routine must complete without throwing. In every other case all
 * four routines must reject the call with an {@link RSRuntimeException}.
 *
 * <p>The validator is handed A's own element as the reference type, so a pair whose
 * element is none of the four supported types can still pass it. Such pairs have no TRMM
 * variant to run, so they are checked for rejection instead of being silently skipped.
 *
 * @param Side    side selector forwarded unchanged to the BLAS calls
 * @param Uplo    triangle selector forwarded unchanged to the BLAS calls
 * @param TransA  transpose selector forwarded unchanged to the BLAS calls
 * @param Diag    diagonal selector forwarded unchanged to the BLAS calls
 * @param mMatrix pool of allocations used for A and B
 */
private void xTRMM_API_test(int Side, int Uplo, int TransA, int Diag, ArrayList<Allocation> mMatrix) {
    for (Allocation matA : mMatrix) {
        for (Allocation matB : mMatrix) {
            Element elemA = matA.getType().getElement();
            boolean valid = validateTRMM(elemA, Side, Uplo, TransA, Diag, matA, matB);
            if (valid) {
                try {
                    if (elemA.isCompatible(Element.F32(mRS))) {
                        mBLAS.STRMM(Side, Uplo, TransA, Diag, alphaS, matA, matB);
                    } else if (elemA.isCompatible(Element.F64(mRS))) {
                        mBLAS.DTRMM(Side, Uplo, TransA, Diag, alphaD, matA, matB);
                    } else if (elemA.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CTRMM(Side, Uplo, TransA, Diag, alphaC, matA, matB);
                    } else if (elemA.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZTRMM(Side, Uplo, TransA, Diag, alphaZ, matA, matB);
                    } else {
                        // No TRMM variant exists for this element type: every routine
                        // should reject it, so fall through to the negative checks.
                        valid = false;
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException: " + e.getMessage());
                }
            }
            if (!valid) {
                try {
                    mBLAS.STRMM(Side, Uplo, TransA, Diag, alphaS, matA, matB);
                    fail("should throw RSRuntimeException for STRMM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
                try {
                    mBLAS.DTRMM(Side, Uplo, TransA, Diag, alphaD, matA, matB);
                    fail("should throw RSRuntimeException for DTRMM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
                try {
                    mBLAS.CTRMM(Side, Uplo, TransA, Diag, alphaC, matA, matB);
                    fail("should throw RSRuntimeException for CTRMM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
                try {
                    mBLAS.ZTRMM(Side, Uplo, TransA, Diag, alphaZ, matA, matB);
                    fail("should throw RSRuntimeException for ZTRMM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
            }
        }
    }
}
/**
 * Sweeps {@link #xTRMM_API_test} over the full cross product of {@code mSide},
 * {@code mUplo}, {@code mTranspose} and {@code mDiag}.
 *
 * @param mMatrix pool of allocations used for A and B
 */
public void L3_xTRMM_API(ArrayList<Allocation> mMatrix) {
    for (int s = 0; s < mSide.length; s++) {
        final int side = mSide[s];
        for (int u = 0; u < mUplo.length; u++) {
            final int uplo = mUplo[u];
            for (int t = 0; t < mTranspose.length; t++) {
                final int trans = mTranspose[t];
                for (int d = 0; d < mDiag.length; d++) {
                    xTRMM_API_test(side, uplo, trans, mDiag[d], mMatrix);
                }
            }
        }
    }
}

/** Runs the TRMM argument sweep over the {@code mMatrixS} pool. */
public void test_L3_STRMM_API() {
    L3_xTRMM_API(this.mMatrixS);
}

/** Runs the TRMM argument sweep over the {@code mMatrixD} pool. */
public void test_L3_DTRMM_API() {
    L3_xTRMM_API(this.mMatrixD);
}

/** Runs the TRMM argument sweep over the {@code mMatrixC} pool. */
public void test_L3_CTRMM_API() {
    L3_xTRMM_API(this.mMatrixC);
}

/** Runs the TRMM argument sweep over the {@code mMatrixZ} pool. */
public void test_L3_ZTRMM_API() {
    L3_xTRMM_API(this.mMatrixZ);
}
/**
 * Predicts whether a TRSM call with these arguments should be accepted.
 *
 * <p>Returns {@code true} only when all of the following hold:
 * <ul>
 *   <li>Side, TransA, Uplo and Diag each pass their validator (checked in that order);</li>
 *   <li>both A and B have an element compatible with {@code e};</li>
 *   <li>A is square (X == Y);</li>
 *   <li>A's dimension equals B's row count (Y) for {@code LEFT}, or B's column count (X)
 *       otherwise.</li>
 * </ul>
 *
 * @param e      reference element both A and B must be compatible with
 * @param Side   side selector
 * @param Uplo   triangle selector
 * @param TransA transpose selector for A
 * @param Diag   diagonal selector
 * @param A      triangular operand
 * @param B      right-hand side / result matrix
 * @return {@code true} if the arguments are expected to be accepted
 */
private boolean validateTRSM(Element e, int Side, int Uplo, int TransA, int Diag, Allocation A, Allocation B) {
    if (!(validateSide(Side)
            && validateTranspose(TransA)
            && validateUplo(Uplo)
            && validateDiag(Diag))) {
        return false;
    }
    if (!(A.getType().getElement().isCompatible(e)
            && B.getType().getElement().isCompatible(e))) {
        return false;
    }

    final int order = A.getType().getX();
    if (order != A.getType().getY()) {
        // Requiring a square A may be stricter than necessary: A only has to contain the
        // triangular operand and could in principle be larger. For now adapters are assumed
        // to be sufficient; this may be revisited.
        return false;
    }

    final int bRows = B.getType().getY();
    final int bCols = B.getType().getX();
    // LEFT: A is M*M and must match B's row count. RIGHT: A is N*N and must match B's columns.
    final int required = (Side == ScriptIntrinsicBLAS.LEFT) ? bRows : bCols;
    return order == required;
}
/**
 * Exercises STRSM/DTRSM/CTRSM/ZTRSM with every (A, B) pair drawn from {@code mMatrix}
 * for a single Side/Uplo/TransA/Diag combination.
 *
 * <p>If {@link #validateTRSM} accepts the pair and A's element is F32, F64, F32_2 or
 * F64_2, the matching routine must complete without throwing. In every other case all
 * four routines must reject the call with an {@link RSRuntimeException}.
 *
 * <p>The validator is handed A's own element as the reference type, so a pair whose
 * element is none of the four supported types can still pass it. Such pairs have no TRSM
 * variant to run, so they are checked for rejection instead of being silently skipped.
 *
 * @param Side    side selector forwarded unchanged to the BLAS calls
 * @param Uplo    triangle selector forwarded unchanged to the BLAS calls
 * @param TransA  transpose selector forwarded unchanged to the BLAS calls
 * @param Diag    diagonal selector forwarded unchanged to the BLAS calls
 * @param mMatrix pool of allocations used for A and B
 */
private void xTRSM_API_test(int Side, int Uplo, int TransA, int Diag, ArrayList<Allocation> mMatrix) {
    for (Allocation matA : mMatrix) {
        for (Allocation matB : mMatrix) {
            Element elemA = matA.getType().getElement();
            boolean valid = validateTRSM(elemA, Side, Uplo, TransA, Diag, matA, matB);
            if (valid) {
                try {
                    if (elemA.isCompatible(Element.F32(mRS))) {
                        mBLAS.STRSM(Side, Uplo, TransA, Diag, alphaS, matA, matB);
                    } else if (elemA.isCompatible(Element.F64(mRS))) {
                        mBLAS.DTRSM(Side, Uplo, TransA, Diag, alphaD, matA, matB);
                    } else if (elemA.isCompatible(Element.F32_2(mRS))) {
                        mBLAS.CTRSM(Side, Uplo, TransA, Diag, alphaC, matA, matB);
                    } else if (elemA.isCompatible(Element.F64_2(mRS))) {
                        mBLAS.ZTRSM(Side, Uplo, TransA, Diag, alphaZ, matA, matB);
                    } else {
                        // No TRSM variant exists for this element type: every routine
                        // should reject it, so fall through to the negative checks.
                        valid = false;
                    }
                } catch (RSRuntimeException e) {
                    fail("should NOT throw RSRuntimeException: " + e.getMessage());
                }
            }
            if (!valid) {
                try {
                    mBLAS.STRSM(Side, Uplo, TransA, Diag, alphaS, matA, matB);
                    fail("should throw RSRuntimeException for STRSM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
                try {
                    mBLAS.DTRSM(Side, Uplo, TransA, Diag, alphaD, matA, matB);
                    fail("should throw RSRuntimeException for DTRSM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
                try {
                    mBLAS.CTRSM(Side, Uplo, TransA, Diag, alphaC, matA, matB);
                    fail("should throw RSRuntimeException for CTRSM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
                try {
                    mBLAS.ZTRSM(Side, Uplo, TransA, Diag, alphaZ, matA, matB);
                    fail("should throw RSRuntimeException for ZTRSM");
                } catch (RSRuntimeException expected) {
                    // Rejection is the expected outcome.
                }
            }
        }
    }
}
/**
 * Sweeps {@link #xTRSM_API_test} over the full cross product of {@code mSide},
 * {@code mUplo}, {@code mTranspose} and {@code mDiag}.
 *
 * @param mMatrix pool of allocations used for A and B
 */
public void L3_xTRSM_API(ArrayList<Allocation> mMatrix) {
    for (int s = 0; s < mSide.length; s++) {
        final int side = mSide[s];
        for (int u = 0; u < mUplo.length; u++) {
            final int uplo = mUplo[u];
            for (int t = 0; t < mTranspose.length; t++) {
                final int trans = mTranspose[t];
                for (int d = 0; d < mDiag.length; d++) {
                    xTRSM_API_test(side, uplo, trans, mDiag[d], mMatrix);
                }
            }
        }
    }
}

/** Runs the TRSM argument sweep over the {@code mMatrixS} pool. */
public void test_L3_STRSM_API() {
    L3_xTRSM_API(this.mMatrixS);
}

/** Runs the TRSM argument sweep over the {@code mMatrixD} pool. */
public void test_L3_DTRSM_API() {
    L3_xTRSM_API(this.mMatrixD);
}

/** Runs the TRSM argument sweep over the {@code mMatrixC} pool. */
public void test_L3_CTRSM_API() {
    L3_xTRSM_API(this.mMatrixC);
}

/** Runs the TRSM argument sweep over the {@code mMatrixZ} pool. */
public void test_L3_ZTRSM_API() {
    L3_xTRSM_API(this.mMatrixZ);
}
}