R"(
/*
* Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
/*
* Copyright (c) 2016-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef ARM_COMPUTE_HELPER_H
#define ARM_COMPUTE_HELPER_H
/*
* Copyright (c) 2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
/** Store the 0th to (n-1)th rows of the given variables
* @name STORE_ROW_n
*
* @param[in] N0 The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @{
*/
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n
/** Convert and store the 0th to (n-1)th rows of the given variables
* @name CONVERT_STORE_ROW_n
*
* @param[in] N0 The size of the vectors
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @{
*/
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE(N0) \
(CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group CONVERT_STORE_ROW_n
/** Store a block of the given size M0xN0
* @name STORE_BLOCK
*
* Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
* The data to store is expected to have consecutive names for each row.
* E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
* The Z offset is expected to have consecutive names.
* E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
*
* @param[in] M0 The number of rows to store
* @param[in] N0 The size of each vector
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @{
*/
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
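// Usage sketch (illustrative only; the kernel-local names c0, c1, zout0, zout1, dst_addr and dst_stride_y
// are assumptions, not part of this header). Assuming two float4 accumulator rows c0/c1, a __global byte
// pointer dst_addr and byte offsets zout0/zout1:
//
//     STORE_BLOCK(2, 4, float, c, dst_addr, dst_stride_y, zout);
//
// expands to one vstore per row:
//
//     vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout0));
//     vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout1));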
/** Convert and store a block of the given size M0xN0
* @name CONVERT_STORE_BLOCK
*
* Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
* The data to store is expected to have consecutive names for each row.
* E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
* The Z offset is expected to have consecutive names.
* E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
*
* @param[in] M0 The number of rows to store
* @param[in] N0 The size of each vector
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @{
*/
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK
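// Usage sketch (illustrative only; c0, dst_addr, dst_stride_y and zout0 are assumed kernel-local names).
// Assuming an int4 accumulator row c0 that has to be saturated down to uchar on store:
//
//     CONVERT_STORE_BLOCK(1, 4, uchar, c, dst_addr, dst_stride_y, zout);
//
// expands to:
//
//     vstore4(convert_uchar4_sat(c0), 0, (__global uchar *)(dst_addr + 0 * dst_stride_y + zout0));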
/** Partially store the 0th to (n-1)th rows of the given variables
* @name STORE_ROW_PARTIAL_n
* Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
*
* @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
*
* @param[in] N0 The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
* @param[in] STORE_N0 The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @{
*/
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
VSTORE_PARTIAL(N0, STORE_N0) \
(BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n
/** Partially store a block of the given size STORE_M0xSTORE_N0
* @name STORE_BLOCK_PARTIAL
*
* @note The vector width @p N0 is also required for correct partial storing behaviour.
* @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
*
* The data to store is expected to have consecutive names for each row.
* E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
* The Z offset is expected to have consecutive names.
* E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
*
* @param[in] STORE_M0 The number of rows to store. Supported: 1-16
* @param[in] STORE_N0 The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
* @param[in] N0 The size of each vector. Supported: 1, 2, 3, 4, 8, 16
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @{
*/
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** Store a block that can be partial in both x and y dimensions
*
* @note in case @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
*
* The data to store is expected to have consecutive names for each row.
* E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
* The Z offset is expected to have consecutive names.
* E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
*
* @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16
* @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
* @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
* @param[in] PARTIAL_COND_Y Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
* @param[in] PARTIAL_COND_X Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
*/
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y)) \
{ \
STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
} \
else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X)) \
{ \
STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
} \
else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X)) \
{ \
STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
} \
else \
{ \
STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
}
/** Store a block that can only be partial in x but not y.
*
* @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
*
* The data to store is expected to have consecutive names for each row.
* E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
* The Z offset is expected to have consecutive names.
* E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
*
* @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16
* @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
* @param[in] PARTIAL_COND_X Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
*/
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
if(!(PARTIAL_COND_X)) \
{ \
STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
} \
else \
{ \
STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
}
/** Store a block that can only be partial in y but not x.
*
* @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
*
* The data to store is expected to have consecutive names for each row.
* E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
* The Z offset is expected to have consecutive names.
* E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
*
* @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16
* @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
* @param[in] PARTIAL_COND_Y Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
*/
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
if(!(PARTIAL_COND_Y)) \
{ \
STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
} \
else \
{ \
STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
}
/** @} */ // end of group STORE_BLOCK_PARTIAL
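// Usage sketch (illustrative only; c0, c1, zout0, zout1, dst_addr and dst_stride_y are assumed
// kernel-local names). Storing only the lower 3 elements of two float4 rows:
//
//     STORE_BLOCK_PARTIAL(2, 3, 4, float, c, dst_addr, dst_stride_y, zout);
//
// expands to:
//
//     vstore3(c0.s012, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout0));
//     vstore3(c1.s012, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout1));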
#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** Boundary-aware GEMM block store
* @name STORE_BLOCK_BOUNDARY_AWARE
* This macro assumes the following schemes to achieve boundary-awareness:
* - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
* - Non-overlapping (normal) load from rhs tensor. This implies rhs can have paddings.
* - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
* The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
*
* In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
* blocks **at the end**.
* Say the dst tensor is of shape MxN and we have M0 and N0 as the block size; this is how we define "partial blocks"/
* "boundary blocks" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and their various parameters:
*
* *--x--> x == 0 x == 1
* | |<------------------------------N-------------------------->|
* y |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
* | -------------#############################################################
* * | | |...............................|...........................|
* y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
* | | |...............................|...........................|
* M --#############################################################
* | | | |...........................|
* y == 1 | M0 | Non-boundary block |....Boundary block in x....|
* | | | |...........................|
* |------------#############################################################
*
* Then @p PARTIAL_STORE_M0 = M % M0 and @p PARTIAL_STORE_N0 = N % N0
*
* @note in case @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
*
* It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
* and selects the corresponding store methods such that the boundary detection logic is only added when needed.
*
* The data to store is expected to have consecutive names for each row.
* E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
* The Z offset is expected to have consecutive names.
* E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
*
* @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16
* @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
* @param[in] DATA_TYPE The data type of the vectors
* @param[in] BASENAME The basename of the variables
* @param[in] PTR The base pointer
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
* @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
* @param[in] PARTIAL_COND_Y Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
* @param[in] PARTIAL_COND_X Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
* @{
*/
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)
#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)
#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case4: Partial blocks in both x and y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)
#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
#endif // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
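// Usage sketch (illustrative only; M0, N0, N, DATA_TYPE, c, dst_addr, dst_stride_y and zout are assumed
// to be provided by the calling kernel, and PARTIAL_STORE_M0 / PARTIAL_STORE_N0 to be passed as build
// options, e.g. -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=3):
//
//     const bool cond_y = (get_global_id(1) == 0);             // partial rows sit at the beginning in y
//     const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);  // partial columns sit at the end in x
//     STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
//                                PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);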
#if defined(PARTIAL_STORE_M0)
/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
* @name COMPUTE_M0_START_ROW
* If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
* This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
* blocks in the y dimension to avoid any padding.
* EG: M0=4, PARTIAL_STORE_M0=1:
* | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
* block 0 (partial)| start row = 0 | start row = 0
* block 1 (full) | start row = 4 | start row = 1
* block 2 (full) | start row = 8 | start row = 5
*
* @param[in] y Global id of current block in y.
* @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
* @{
*/
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
#else // defined(PARTIAL_STORE_M0)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
((uint)(y * M0))
#endif // defined(PARTIAL_STORE_M0)
/** @} */ // end of group COMPUTE_M0_START_ROW
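// Worked example (matches the table above; M0 = 4, PARTIAL_STORE_M0 = 1):
//     (M0 - PARTIAL_STORE_M0) % M0 = 3, so COMPUTE_M0_START_ROW(y, 4, 1) = max(0, y * 4 - 3):
//     y == 0 -> 0, y == 1 -> 1, y == 2 -> 5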
/** Store a vector that can only be partial in x.
*
* @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
*
* The data to store is expected to have a name ending in a 0.
* E.g., for basename=c, the expected name is c0.
*
* @param[in] basename The name of the variable without trailing 0
* @param[in] data_type The data type of the vector
* @param[in] ptr The base pointer
* @param[in] vec_size The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
* @param[in] leftover The vector size if cond = true. Supported range: [1, @p vec_size)
* @param[in] cond Condition to select either @p vec_size (if false) or @p leftover (if true)
* @{
*/
#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
/** @} */ // end of group STORE_VECTOR_SELECT
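// Usage sketch (illustrative only; res0 is an assumed float4 variable and dst_addr an assumed __global
// pointer to the destination):
//
//     STORE_VECTOR_SELECT(res, float, dst_addr, 4, 2, x_is_leftover);
//
// stores the full vector (vstore4 of res0) when x_is_leftover is false, and only the lower two elements
// (vstore2 of res0.s01) when it is true.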
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
#pragma OPENCL EXTENSION cl_arm_printf : enable
#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
#define GPU_ARCH_MIDGARD 0x100
#define GPU_ARCH_BIFROST 0x200
/** Concatenate two inputs.
*
* @param[in] a The first input to be concatenated
* @param[in] b The second input to be concatenated
*
* @return The concatenated output
*/
#define CONCAT(a, b) a##b
/** Expand the given vector
*
* @param[in] x The vector to be expanded
*
* @return The expanded output
*/
#define EXPAND(x) x
/** Clamp the given value between an upper and lower bound.
*
* @param[in] x The value to be clamped
* @param[in] min_val The lower bound
* @param[in] max_val The upper bound
*
* @return The clamped value.
*/
#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
/** REVn reverses the given vector whose size is n.
* @name REVn
*
* @param[in] x The vector to be reversed
*
* @return The reversed vector
* @{
*/
#define REV1(x) ((x))
#define REV2(x) ((x).s10)
#define REV3(x) ((x).s210)
#define REV4(x) ((x).s3210)
#define REV8(x) ((x).s76543210)
#define REV16(x) ((x).sFEDCBA9876543210)
/** @} */ // end of group REVn
/** Reverse the given vector.
* @name REVERSE
*
* @param[in] x The vector to be reversed
* @param[in] s The size of the vector
*
* @return The reversed vector
* @{
*/
#define REVERSE_STR(x, s) REV##s((x))
#define REVERSE(x, s) REVERSE_STR(x, s)
/** @} */ // end of group REVERSE
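// Example: with uchar4 v = (uchar4)(1, 2, 3, 4), REVERSE(v, 4) expands to REV4(v), i.e. v.s3210,
// which yields (uchar4)(4, 3, 2, 1).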
/** Circular-right-shift (rotate-right) the given vector of size s by n positions.
* @name ROTs_n
*
* @param[in] x The vector to be shifted
*
* @return The shifted vector
* @{
*/
#define ROT1_0(x) ((x))
#define ROT2_0(x) ((x))
#define ROT2_1(x) ((x).s10)
#define ROT3_0(x) ((x))
#define ROT3_1(x) ((x).s201)
#define ROT3_2(x) ((x).s120)
#define ROT4_0(x) ((x))
#define ROT4_1(x) ((x).s3012)
#define ROT4_2(x) ((x).s2301)
#define ROT4_3(x) ((x).s1230)
#define ROT8_0(x) ((x))
#define ROT8_1(x) ((x).s70123456)
#define ROT8_2(x) ((x).s67012345)
#define ROT8_3(x) ((x).s56701234)
#define ROT8_4(x) ((x).s45670123)
#define ROT8_5(x) ((x).s34567012)
#define ROT8_6(x) ((x).s23456701)
#define ROT8_7(x) ((x).s12345670)
#define ROT16_0(x) ((x))
#define ROT16_1(x) ((x).sF0123456789ABCDE)
#define ROT16_2(x) ((x).sEF0123456789ABCD)
#define ROT16_3(x) ((x).sDEF0123456789ABC)
#define ROT16_4(x) ((x).sCDEF0123456789AB)
#define ROT16_5(x) ((x).sBCDEF0123456789A)
#define ROT16_6(x) ((x).sABCDEF0123456789)
#define ROT16_7(x) ((x).s9ABCDEF012345678)
#define ROT16_8(x) ((x).s89ABCDEF01234567)
#define ROT16_9(x) ((x).s789ABCDEF0123456)
#define ROT16_10(x) ((x).s6789ABCDEF012345)
#define ROT16_11(x) ((x).s56789ABCDEF01234)
#define ROT16_12(x) ((x).s456789ABCDEF0123)
#define ROT16_13(x) ((x).s3456789ABCDEF012)
#define ROT16_14(x) ((x).s23456789ABCDEF01)
#define ROT16_15(x) ((x).s123456789ABCDEF0)
/** @} */ // end of group ROTs_n
/** Circular-right-shift (rotate-right) the given vector by the given amount.
* @name ROTATE
*
* @param[in] x The vector to be shifted
* @param[in] s The size of the vector
* @param[in] n The amount to be shifted
*
* @return The shifted vector
* @{
*/
#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
/** @} */ // end of group ROTATE
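// Example: with uchar4 v = (uchar4)(1, 2, 3, 4), ROTATE(v, 4, 1) expands to ROT4_1(v), i.e. v.s3012,
// which yields (uchar4)(4, 1, 2, 3) (the last element wraps around to the front).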
/** Creates a vector of size n filled with offset values corresponding to the location of each element.
* @name V_OFFSn
*
* @param[in] dt The data type of the output vector
*
* @return The vector filled with offset values
* @{
*/
#define V_OFFS1(dt) (dt##1)(0)
#define V_OFFS2(dt) (dt##2)(0, 1)
#define V_OFFS3(dt) (dt##3)(0, 1, 2)
#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
/** @} */ // end of group V_OFFSn
/** Create a vector filled with offset values corresponding to the location of each element.
* @name VEC_OFFS
*
* @param[in] dt The data type of the output vector
* @param[in] s The size of the output vector
*
* @return The vector filled with offset values
* @{
*/
#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
/** @} */ // end of group VEC_OFFS
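// Example: VEC_OFFS(int, 4) expands to V_OFFS4(int), i.e. (int4)(0, 1, 2, 3).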
#define VLOAD_STR(size) vload##size
#define VLOAD(size) VLOAD_STR(size)
#define PIXEL_UNIT4 1
#define PIXEL_UNIT8 2
#define PIXEL_UNIT16 4
/** Utility macro to convert a vector size to pixel units.
*
* @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
*
* @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
*
* @return The pixel unit (number of pixels)
* @{
*/
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
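// Example: CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(16) expands to PIXEL_UNIT16, i.e. 4, since sixteen values
// correspond to four 4-channel (RGBA) pixels.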
#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
/** Utility macro to read a 2D OpenCL image object.
*
* @note Coordinates are not normalized
*
* @param[in] data_type Data type
* @param[in] n0 Number of pixels to read. Only 1, 2 and 4 are supported
* @param[in] img OpenCL image object
* @param[in] x_coord The x coordinate for the top-left pixel
* @param[in] y_coord The y coordinate for the top-left pixel
*
* @return Pixels from the 2D OpenCL image object
* @{
*/
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
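// Example (illustrative only; img, x and y are assumed kernel-local names): READ_IMAGE2D(float, 4, img, x, y)
// expands to read_image2d_floatx4(img, x, y), i.e. a float16 built from four read_imagef calls at
// (x, y) .. (x + 3, y).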
#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)
#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double
#define vload1(OFFSET, PTR) *(OFFSET + PTR)
#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA
/** Extended partial vstore that correctly handles scalar values as well.
* Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
* @name VSTORE_PARTIAL
*
* @note With this macro, the passed data can be both a vector and a scalar
* @note @p store_size needs to be <= @p size
* eg 1: Valid
* VSTORE_PARTIAL(16, 15) ...;
* eg 2: Invalid
* VSTORE_PARTIAL(4, 7) ...;
*
* @param[in] size The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
* @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
* @{
*/
#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)
#define NO_STORE(data, offs, ptr) \
{ \
}
// Size == 1 (scalar)
#define vstore_partial_1_0 NO_STORE
#define vstore_partial_1_1 vstore1
#define vstore_partial_1_2 NO_STORE
#define vstore_partial_1_3 NO_STORE
#define vstore_partial_1_4 NO_STORE
#define vstore_partial_1_5 NO_STORE
#define vstore_partial_1_6 NO_STORE
#define vstore_partial_1_7 NO_STORE
#define vstore_partial_1_8 NO_STORE
#define vstore_partial_1_9 NO_STORE
#define vstore_partial_1_10 NO_STORE
#define vstore_partial_1_11 NO_STORE
#define vstore_partial_1_12 NO_STORE
#define vstore_partial_1_13 NO_STORE
#define vstore_partial_1_14 NO_STORE
#define vstore_partial_1_15 NO_STORE
#define vstore_partial_1_16 NO_STORE
// Size == 2
#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE
// Size == 3
#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE
// Size == 4
#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE
// Size == 8
#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE
// Size == 16
#define vstore_partial_16_0 NO_STORE
#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16
/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
* @name vstore_partial_n
*
* @note @p DATA needs to be a vector not a scalar
* @note n needs to be <= the vector width of the input variable @p DATA
* eg 1: Valid
* vstore_partial_15(var:float16, 0, 0xabcd);
* eg 2: Invalid
* vstore_partial_7(var:float4, 0, 0xabcd);
*
* @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
*
* @param[in] DATA The name of the variable
* @param[in] OFFSET Offset in n
* @param[in] PTR The base pointer
* @{
*/
#define vstore_partial_1(DATA, OFFSET, PTR) \
vstore1(DATA.s0, OFFSET, PTR);
#define vstore_partial_2(DATA, OFFSET, PTR) \
vstore2(DATA.s01, OFFSET, PTR);
#define vstore_partial_3(DATA, OFFSET, PTR) \
vstore3(DATA.s012, OFFSET, PTR);
#define vstore_partial_4(DATA, OFFSET, PTR) \
vstore4(DATA.s0123, OFFSET, PTR);
#define vstore_partial_5(DATA, OFFSET, PTR) \
vstore_partial_4(DATA.s0123, OFFSET, PTR); \
vstore1(DATA.s4, OFFSET, PTR + 4);
#define vstore_partial_6(DATA, OFFSET, PTR) \
vstore_partial_4(DATA.s0123, OFFSET, PTR); \
vstore_partial_2(DATA.s45, OFFSET, PTR + 4);
#define vstore_partial_7(DATA, OFFSET, PTR) \
vstore_partial_4(DATA.s0123, OFFSET, PTR); \
vstore_partial_3(DATA.s456, OFFSET, PTR + 4);
#define vstore_partial_8(DATA, OFFSET, PTR) \
vstore8(DATA.s01234567, OFFSET, PTR);
#define vstore_partial_9(DATA, OFFSET, PTR) \
vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
vstore1(DATA.s8, OFFSET, PTR + 8);
#define vstore_partial_10(DATA, OFFSET, PTR) \
vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
vstore_partial_2(DATA.s89, OFFSET, PTR + 8);
#define vstore_partial_11(DATA, OFFSET, PTR) \
vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);
#define vstore_partial_12(DATA, OFFSET, PTR) \
vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);
#define vstore_partial_13(DATA, OFFSET, PTR) \
vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);
#define vstore_partial_14(DATA, OFFSET, PTR) \
vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);
#define vstore_partial_15(DATA, OFFSET, PTR) \
vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);
#define vstore_partial_16(DATA, OFFSET, PTR) \
vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
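// Expansion sketch (illustrative only; acc is an assumed float8 variable and ptr an assumed __global
// float pointer):
//
//     VSTORE_PARTIAL(8, 5)(acc, 0, ptr);
//
// resolves to vstore_partial_8_5 -> vstore_partial_5, i.e. a vstore4 of the lower four elements of acc
// followed by a scalar store of acc.s4 at ptr + 4.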
// Convert built-in functions with the _sat modifier are not supported for floating point, so we create
// defines without _sat to overcome this issue
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_half
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16
#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double
#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat
#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)
#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)
#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)
#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)
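// Examples: VEC_DATA_TYPE(half, 8) expands to half8; CONVERT(x, int4) becomes convert_int4((x));
// CONVERT_SAT(acc, VEC_DATA_TYPE(uchar, 16)) becomes convert_uchar16_sat((acc)).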
#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size
#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)
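// Example: SELECT_VEC_DATA_TYPE(float, 4) expands to int4 and SELECT_VEC_DATA_TYPE(half, 8) to short8,
// i.e. an integer type of matching element width and vector size, as required for the mask operand of
// OpenCL's select() with float/half vectors.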
#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) ((x).s0) + ((x).s1)
#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2)
#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23)
#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567)
#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF)
#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)
#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))
#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
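// Examples: SUM_REDUCE(v, 4) expands to sum_reduce_4(v), effectively v.s0 + v.s1 + v.s2 + v.s3;
// MAX_REDUCE(v, 4) expands to max(max(v.s0, v.s1), max(v.s2, v.s3)).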
#define VECTOR_DECLARATION(name) \
__global uchar *name##_ptr, \
uint name##_stride_x, \
uint name##_step_x, \
uint name##_offset_first_element_in_bytes
#define IMAGE_DECLARATION(name) \
__global uchar *name##_ptr, \
uint name##_stride_x, \
uint name##_step_x, \
uint name##_stride_y, \
uint name##_step_y, \
uint name##_offset_first_element_in_bytes
#define TENSOR3D_DECLARATION(name) \
__global uchar *name##_ptr, \
uint name##_stride_x, \
uint name##_step_x, \
uint name##_stride_y, \
uint name##_step_y, \
uint name##_stride_z, \
uint name##_step_z, \
uint name##_offset_first_element_in_bytes
#define TENSOR4D_DECLARATION(name) \
__global uchar *name##_ptr, \
uint name##_stride_x, \
uint name##_step_x, \
uint name##_stride_y, \
uint name##_step_y, \
uint name##_stride_z, \
uint name##_step_z, \
uint name##_stride_w, \
uint name##_step_w, \
uint name##_offset_first_element_in_bytes
#define CONVERT_TO_VECTOR_STRUCT(name) \
update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)
#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)
#define CONVERT_TO_IMAGE_STRUCT(name) \
update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)
#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)
#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)
#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
#define CONVERT_TO_TENSOR3D_STRUCT(name) \
update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
name##_stride_z, name##_step_z)
#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)
#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size) \
update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)
#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)
#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name) \
tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
name##_stride_z, name##_step_z)
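// Usage sketch (illustrative only; the kernel name "example_copy" and the tensor name "src" are
// assumptions, not part of this header):
//
//     __kernel void example_copy(IMAGE_DECLARATION(src))
//     {
//         Image src = CONVERT_TO_IMAGE_STRUCT(src);
//         // src.ptr now points at this work-item's element:
//         // src_ptr + src_offset_first_element_in_bytes
//         //         + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y
//     }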
/** Structure to hold Vector information */
typedef struct Vector
{
__global uchar *ptr; /**< Pointer to the starting position of the buffer */
int offset_first_element_in_bytes; /**< The offset of the first element in the source image */
int stride_x; /**< Stride of the image in X dimension (in bytes) */
} Vector;
/** Structure to hold Image information */
typedef struct Image
{
__global uchar *ptr; /**< Pointer to the starting position of the buffer */
int offset_first_element_in_bytes; /**< The offset of the first element in the source image */
int stride_x; /**< Stride of the image in X dimension (in bytes) */
int stride_y; /**< Stride of the image in Y dimension (in bytes) */
} Image;
/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
__global uchar *ptr; /**< Pointer to the starting position of the buffer */
int offset_first_element_in_bytes; /**< The offset of the first element in the source image */
int stride_x; /**< Stride of the image in X dimension (in bytes) */
int stride_y; /**< Stride of the image in Y dimension (in bytes) */
int stride_z; /**< Stride of the image in Z dimension (in bytes) */
} Tensor3D;
/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
__global uchar *ptr; /**< Pointer to the starting position of the buffer */
int offset_first_element_in_bytes; /**< The offset of the first element in the source image */
int stride_x; /**< Stride of the image in X dimension (in bytes) */
int stride_y; /**< Stride of the image in Y dimension (in bytes) */
int stride_z; /**< Stride of the image in Z dimension (in bytes) */
int stride_w; /**< Stride of the image in W dimension (in bytes) */
} Tensor4D;
/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
*
* @param[in] ptr Pointer to the starting position of the buffer
* @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
* @param[in] stride_x Stride of the vector in X dimension (in bytes)
* @param[in] step_x stride_x * number of elements along X processed per workitem(in bytes)
*
* @return A vector object
*/
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
Vector vector =
{
.ptr = ptr,
.offset_first_element_in_bytes = offset_first_element_in_bytes,
.stride_x = stride_x,
};
vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
return vector;
}
/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
*
 * @param[in] ptr Pointer to the starting position of the buffer
* @param[in] offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] stride_x Stride of the image in X dimension (in bytes)
* @param[in] step_x stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] stride_y Stride of the image in Y dimension (in bytes)
* @param[in] step_y stride_y * number of elements along Y processed per workitem(in bytes)
*
* @return An image object
*/
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
Image img =
{
.ptr = ptr,
.offset_first_element_in_bytes = offset_first_element_in_bytes,
.stride_x = stride_x,
.stride_y = stride_y
};
img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
return img;
}
/** Wrap 3D tensor information into an image structure, and make the pointer point at this workitem's data.
*
 * @param[in] ptr Pointer to the starting position of the buffer
* @param[in] offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] stride_x Stride of the image in X dimension (in bytes)
* @param[in] step_x stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] stride_y Stride of the image in Y dimension (in bytes)
* @param[in] step_y stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] stride_z Stride of the image in Z dimension (in bytes)
* @param[in] step_z stride_z * number of elements along Z processed per workitem(in bytes)
*
 * @return An image object
*/
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
Image img =
{
.ptr = ptr,
.offset_first_element_in_bytes = offset_first_element_in_bytes,
.stride_x = stride_x,
.stride_y = stride_y
};
img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
return img;
}
/** Wrap 3D tensor information into a tensor structure, and make the pointer point at this workitem's data.
*
 * @param[in] ptr Pointer to the starting position of the buffer
* @param[in] offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] stride_x Stride of the image in X dimension (in bytes)
* @param[in] step_x stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] stride_y Stride of the image in Y dimension (in bytes)
* @param[in] step_y stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] stride_z Stride of the image in Z dimension (in bytes)
* @param[in] step_z stride_z * number of elements along Z processed per workitem(in bytes)
*
* @return A 3D tensor object
*/
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
Tensor3D tensor =
{
.ptr = ptr,
.offset_first_element_in_bytes = offset_first_element_in_bytes,
.stride_x = stride_x,
.stride_y = stride_y,
.stride_z = stride_z
};
tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
return tensor;
}
/** Wrap 3D tensor information into a tensor structure, without moving the pointer to this workitem's data.
*
 * @param[in] ptr Pointer to the starting position of the buffer
* @param[in] offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] stride_x Stride of the image in X dimension (in bytes)
* @param[in] step_x stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] stride_y Stride of the image in Y dimension (in bytes)
* @param[in] step_y stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] stride_z Stride of the image in Z dimension (in bytes)
* @param[in] step_z stride_z * number of elements along Z processed per workitem(in bytes)
*
* @return A 3D tensor object
*/
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
Tensor3D tensor =
{
.ptr = ptr,
.offset_first_element_in_bytes = offset_first_element_in_bytes,
.stride_x = stride_x,
.stride_y = stride_y,
.stride_z = stride_z
};
return tensor;
}
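/** Wrap 4D tensor information into a Tensor4D structure, and make the pointer point at this workitem's data.
 *
 * The third global id is split by mod_size: (get_global_id(2) % mod_size) selects the Z slice and
 * (get_global_id(2) / mod_size) selects the W slice.
 *
 * @param[in] ptr Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] stride_y Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] stride_z Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] stride_w Stride of the tensor in W dimension (in bytes)
 * @param[in] step_w stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in] mod_size Size of the Z dimension, used to split get_global_id(2) into its Z and W components
 *
 * @return A 4D tensor object
 */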
inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
uint step_w,
uint mod_size)
{
Tensor4D tensor =
{
.ptr = ptr,
.offset_first_element_in_bytes = offset_first_element_in_bytes,
.stride_x = stride_x,
.stride_y = stride_y,
.stride_z = stride_z,
.stride_w = stride_w
};
tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
return tensor;
}
/** Get the pointer position of a Vector
*
 * @param[in] vec Pointer to the Vector structure
* @param[in] x Relative X position
*/
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
return vec->ptr + x * vec->stride_x;
}
/** Get the pointer position of an Image
*
 * @param[in] img Pointer to the Image structure
* @param[in] x Relative X position
* @param[in] y Relative Y position
*/
inline __global uchar *offset(const Image *img, int x, int y)
{
return img->ptr + x * img->stride_x + y * img->stride_y;
}
/** Get the pointer position of a Tensor3D
*
 * @param[in] tensor Pointer to the Tensor3D structure
* @param[in] x Relative X position
* @param[in] y Relative Y position
* @param[in] z Relative Z position
*/
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
}
/** Get the pointer position of a Tensor4D
*
 * @param[in] tensor Pointer to the Tensor4D structure
* @param[in] x Relative X position
* @param[in] y Relative Y position
* @param[in] z Relative Z position
* @param[in] w Relative W position
*/
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
}
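/* Example (a sketch) for the offset helpers above: given an Image img produced by
 * update_image_workitem_ptr(), the address of the element two columns right and one row below
 * the work-item's element is
 *
 *   __global uchar *p = offset(&img, 2, 1); // img.ptr + 2 * stride_x + 1 * stride_y
 */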
/** Get the offset for a given linear index of a Tensor3D
*
 * @param[in] tensor Pointer to the Tensor3D structure
* @param[in] width Width of the input tensor
* @param[in] height Height of the input tensor
* @param[in] depth Depth of the input tensor
* @param[in] index Linear index
*/
inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
{
uint num_elements = width * height;
const uint z = index / num_elements;
index %= num_elements;
const uint y = index / width;
index %= width;
const uint x = index;
return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
}
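/* Worked example: with width = 4, height = 3 and index = 14, num_elements = 12, so
 * z = 14 / 12 = 1, the remainder is 2, y = 2 / 4 = 0 and x = 2; the returned address is
 * ptr + offset_first_element_in_bytes + 2 * stride_x + 1 * stride_z.
 */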
#endif // ARM_COMPUTE_HELPER_H
#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond) \
({ \
basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond))); \
basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond))); \
basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond))); \
basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond))); \
basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s0) && (z_cond))); \
basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s1) && (z_cond))); \
})
#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond) \
({ \
basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0))); \
basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1))); \
basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2))); \
basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3))); \
basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s0))); \
basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s1))); \
})
#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond) \
({ \
basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond))); \
basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond))); \
basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond))); \
basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond))); \
basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s4) && (z_cond))); \
basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s5) && (z_cond))); \
basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s6) && (z_cond))); \
basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s7) && (z_cond))); \
})
#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond) \
({ \
basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0))); \
basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1))); \
basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2))); \
basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3))); \
basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s4))); \
basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s5))); \
basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s6))); \
basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s7))); \
})
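/* The FILL_ZERO_OUT_OF_BOUND_{6,8}_NHWC_{H,V} macros above zero each of the variables
 * basename0..basename5 (or ..basename7) whose boundary condition is false, i.e. whose source
 * element lies outside the input. The _H variants read per-element Y conditions from the vectors
 * formed by pasting 0/1 onto the given y_cond base name, combined with a scalar Z condition;
 * the _V variants do the opposite. A sketch of a call site (names are illustrative only):
 *
 *   FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(VEC_DATA_TYPE(DATA_TYPE, N0), d, y_cond, z_cond);
 */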
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
({ \
comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \
comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5; \
comm_fact.s2 = 2.5f * tmp.s3; \
comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; \
comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6; \
comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4; \
comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; \
\
out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; \
out.s1 = comm_fact.s0 + comm_fact.s1; \
out.s2 = comm_fact.s0 - comm_fact.s1; \
out.s3 = comm_fact.s3 + comm_fact.s4; \
out.s4 = comm_fact.s4 - comm_fact.s3; \
out.s5 = comm_fact.s5 + comm_fact.s6; \
out.s6 = comm_fact.s5 - comm_fact.s6; \
out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; \
})
#define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact) \
({ \
comm_fact.s0 = 36.0f * tmp.s2 - 13.0f * tmp.s4 + tmp.s6; \
comm_fact.s1 = 36.0f * tmp.s1 - 13.0f * tmp.s3 + 1.0f * tmp.s5; \
comm_fact.s2 = 9.0f * tmp.s2 - 10.0f * tmp.s4 + tmp.s6; \
comm_fact.s3 = 18.0f * tmp.s1 - 20.0f * tmp.s3 + 2.0f * tmp.s5; \
comm_fact.s4 = 4.0f * tmp.s2 - 5.0f * tmp.s4 + tmp.s6; \
comm_fact.s5 = 12.0f * tmp.s1 - 15.0f * tmp.s3 + 3.0f * tmp.s5; \
out.s0 = -36.0f * tmp.s0 + 49.0f * tmp.s2 + -14.0f * tmp.s4 + tmp.s6; \
out.s1 = comm_fact.s0 - comm_fact.s1; \
out.s2 = comm_fact.s0 + comm_fact.s1; \
out.s3 = comm_fact.s2 - comm_fact.s3; \
out.s4 = comm_fact.s2 + comm_fact.s3; \
out.s5 = comm_fact.s4 - comm_fact.s5; \
out.s6 = comm_fact.s4 + comm_fact.s5; \
out.s7 = -36.0f * tmp.s1 + 0.0f * tmp.s2 + 49.0f * tmp.s3 - 14.0f * tmp.s5 + tmp.s7; \
})
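/* The two OUTPUT_ROW_* macros above each compute one 8-element row of a Winograd input transform
 * (one application of the B^T matrix to the 8-element row tmp): OUTPUT_ROW_4x4_5x5 for the
 * F(4x4, 5x5) / F(4, 5) transforms and OUTPUT_ROW_2x2_7x7 for F(2x2, 7x7) / F(2, 7). comm_fact is
 * caller-provided scratch used to factor out terms shared by the symmetric output pairs
 * (s1/s2, s3/s4, s5/s6).
 */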
#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
*
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
* @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
const int x = get_global_id(0);
const int y = get_global_id(1);
#if defined(SRC_DEPTH)
const int z = get_global_id(2) % SRC_DEPTH;
const int b = get_global_id(2) / SRC_DEPTH;
#else /* defined(SRC_DEPTH) */
const int z = get_global_id(2);
#endif /* defined(SRC_DEPTH) */
// Compute input address
#if defined(SRC_DEPTH)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
#else /* defined(SRC_DEPTH) */
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#endif /* defined(SRC_DEPTH) */
src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp0 = in_row0;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
tmp0 -= in_row2;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE out00 = tmp0.s0 - tmp0.s2;
DATA_TYPE out01 = tmp0.s1 + tmp0.s2;
DATA_TYPE out02 = tmp0.s2 - tmp0.s1;
DATA_TYPE out03 = tmp0.s1 - tmp0.s3;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp1 = in_row1 + in_row2;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp2 = in_row2 - in_row1;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp3 = in_row1 - in_row3;
DATA_TYPE out10 = tmp1.s0 - tmp1.s2;
DATA_TYPE out11 = tmp1.s1 + tmp1.s2;
DATA_TYPE out12 = tmp1.s2 - tmp1.s1;
DATA_TYPE out13 = tmp1.s1 - tmp1.s3;
DATA_TYPE out20 = tmp2.s0 - tmp2.s2;
DATA_TYPE out21 = tmp2.s1 + tmp2.s2;
DATA_TYPE out22 = tmp2.s2 - tmp2.s1;
DATA_TYPE out23 = tmp2.s1 - tmp2.s3;
DATA_TYPE out30 = tmp3.s0 - tmp3.s2;
DATA_TYPE out31 = tmp3.s1 + tmp3.s2;
DATA_TYPE out32 = tmp3.s2 - tmp3.s1;
DATA_TYPE out33 = tmp3.s1 - tmp3.s3;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
#if defined(SRC_DEPTH)
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else /* defined(SRC_DEPTH) */
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#endif /* defined(SRC_DEPTH) */
*((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out00;
*((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out01;
*((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out02;
*((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out03;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
*((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out10;
*((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out11;
*((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out12;
*((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out13;
*((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out20;
*((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out21;
*((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out22;
*((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out23;
*((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out30;
*((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out31;
*((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out32;
*((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out33;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
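/* Illustrative build options for the kernel above (example values only):
 *
 *   -DDATA_TYPE=float -DNUM_TILES_X=5 -DPAD_LEFT=1 -DPAD_TOP=1 -DOUTPUT_TILE_W=2 -DOUTPUT_TILE_H=2
 *
 * Optionally add -DSRC_DEPTH=<depth> to enable the batched indexing path, and
 * -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL or -DWINOGRAD_INPUT_TRANSFORM_VERTICAL to select the
 * 3x1 / 1x3 variants instead of the full 3x3 transform.
 */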
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is a multiple of 2
*
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
* @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
const int x = get_global_id(0);
const int y = get_global_id(1);
#if defined(SRC_DEPTH)
const int z = (get_global_id(2) * 2) % SRC_DEPTH;
const int b = (get_global_id(2) * 2) / SRC_DEPTH;
#else /* defined(SRC_DEPTH) */
const int z = get_global_id(2) * 2;
#endif /* defined(SRC_DEPTH) */
// Compute input address
#if defined(SRC_DEPTH)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
#else /* defined(SRC_DEPTH) */
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#endif /* defined(SRC_DEPTH) */
src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
src_addr += src_stride_z;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row4 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row5 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row6 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
in_row7 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp0 = in_row0;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp4 = in_row4;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
tmp0 -= in_row2;
tmp4 -= in_row6;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 2)
out00 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out02 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
VEC_DATA_TYPE(DATA_TYPE, 2)
out03 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp1 = in_row1 + in_row2;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp2 = in_row2 - in_row1;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp3 = in_row1 - in_row3;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp5 = in_row5 + in_row6;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp6 = in_row6 - in_row5;
VEC_DATA_TYPE(DATA_TYPE, 4)
tmp7 = in_row5 - in_row7;
VEC_DATA_TYPE(DATA_TYPE, 2)
out10 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out11 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out12 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
VEC_DATA_TYPE(DATA_TYPE, 2)
out13 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
VEC_DATA_TYPE(DATA_TYPE, 2)
out20 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out21 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out22 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
VEC_DATA_TYPE(DATA_TYPE, 2)
out23 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
VEC_DATA_TYPE(DATA_TYPE, 2)
out30 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out31 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
VEC_DATA_TYPE(DATA_TYPE, 2)
out32 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
VEC_DATA_TYPE(DATA_TYPE, 2)
out33 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
#if defined(SRC_DEPTH)
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else /* defined(SRC_DEPTH) */
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#endif /* defined(SRC_DEPTH) */
vstore2(out00, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z));
vstore2(out01, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z));
vstore2(out02, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z));
vstore2(out03, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z));
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
vstore2(out10, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z));
vstore2(out11, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z));
vstore2(out12, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z));
vstore2(out13, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z));
vstore2(out20, 0, (__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z));
vstore2(out21, 0, (__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z));
vstore2(out22, 0, (__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z));
vstore2(out23, 0, (__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z));
vstore2(out30, 0, (__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z));
vstore2(out31, 0, (__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z));
vstore2(out32, 0, (__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z));
vstore2(out33, 0, (__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
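/* Note: in the _stepz2 variant above each work-item transforms two consecutive input channels
 * (z = get_global_id(2) * 2 and z + 1) and writes them with vstore2, so the NDRange along the
 * third dimension is expected to be half the total number of (channel, batch) slices.
 */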
/** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size is 3x3/3x1 or 1x3 and the data layout is NCHW
*
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
const int x = get_global_id(0);
const int y = get_global_id(1);
#if defined(SRC_DEPTH)
const int z = get_global_id(2) % SRC_DEPTH;
const int b = get_global_id(2) / SRC_DEPTH;
#else /* defined(SRC_DEPTH) */
const int z = get_global_id(2);
#endif /* defined(SRC_DEPTH) */
// Compute input address
#if defined(SRC_DEPTH)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
#else /* defined(SRC_DEPTH) */
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#endif /* defined(SRC_DEPTH) */
src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Row0
VEC_DATA_TYPE(DATA_TYPE, 4)
d00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
VEC_DATA_TYPE(DATA_TYPE, 2)
d01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(*((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)));
#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Row0
VEC_DATA_TYPE(DATA_TYPE, 4)
d00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 2)
d01 = vload2(2, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE out0 = 0.0f;
DATA_TYPE out1 = 0.0f;
DATA_TYPE out2 = 0.0f;
DATA_TYPE out3 = 0.0f;
DATA_TYPE out4 = 0.0f;
DATA_TYPE out5 = 0.0f;
// Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Row4
VEC_DATA_TYPE(DATA_TYPE, 4)
d40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 2)
d41 = vload2(2, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
// k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
DATA_TYPE k0 = d41.s0;
DATA_TYPE k1 = d41.s0;
DATA_TYPE k2 = d41.s0;
DATA_TYPE k3 = d41.s0;
DATA_TYPE k4 = d41.s0;
DATA_TYPE k5 = 0.0f;
k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
out0 += k0;
out1 += k1;
out2 += k2;
out3 += k3;
out4 += k4;
out5 += k5;
// Row2
VEC_DATA_TYPE(DATA_TYPE, 4)
d20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 2)
d21 = vload2(2, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Compute destination address
#if defined(SRC_DEPTH)
__global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
#else /* defined(SRC_DEPTH) */
__global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
#endif /* defined(SRC_DEPTH) */
uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
*(dst_addr) = out0;
dst_addr += dst_plane_stride;
*(dst_addr) = out1;
dst_addr += dst_plane_stride;
*(dst_addr) = out2;
dst_addr += dst_plane_stride;
*(dst_addr) = out3;
dst_addr += dst_plane_stride;
*(dst_addr) = out4;
dst_addr += dst_plane_stride;
*(dst_addr) = out5;
dst_addr += dst_plane_stride;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE out6 = k0;
DATA_TYPE out7 = k1;
DATA_TYPE out8 = k2;
DATA_TYPE out9 = k3;
DATA_TYPE out10 = k4;
DATA_TYPE out11 = k5;
DATA_TYPE out12 = k0;
DATA_TYPE out13 = k1;
DATA_TYPE out14 = k2;
DATA_TYPE out15 = k3;
DATA_TYPE out16 = k4;
DATA_TYPE out17 = k5;
DATA_TYPE out18 = k0;
DATA_TYPE out19 = k1;
DATA_TYPE out20 = k2;
DATA_TYPE out21 = k3;
DATA_TYPE out22 = k4;
DATA_TYPE out23 = k5;
DATA_TYPE out24 = k0;
DATA_TYPE out25 = k1;
DATA_TYPE out26 = k2;
DATA_TYPE out27 = k3;
DATA_TYPE out28 = k4;
DATA_TYPE out29 = k5;
// Row1
VEC_DATA_TYPE(DATA_TYPE, 4)
d10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 2)
d11 = vload2(2, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
// Row3
VEC_DATA_TYPE(DATA_TYPE, 4)
d30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 2)
d31 = vload2(2, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
// Compute common parts for the channels between [6, 29]
// Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
// Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
DATA_TYPE part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
DATA_TYPE part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
DATA_TYPE part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
DATA_TYPE part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
DATA_TYPE part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
DATA_TYPE part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
DATA_TYPE part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
DATA_TYPE part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
DATA_TYPE part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
DATA_TYPE part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
DATA_TYPE part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
DATA_TYPE part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
// Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
// Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
DATA_TYPE part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
DATA_TYPE part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
DATA_TYPE part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
DATA_TYPE part18 = part6 * 0.25f; // d20.s2 - d21.s0
DATA_TYPE part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
DATA_TYPE part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
DATA_TYPE part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
out6 += part0 - part1;
out12 += part0 + part1;
out7 += part2 + part3 + part4 + part5;
out8 += part2 - part3 + part4 - part5;
out13 += part2 + part3 - part4 - part5;
out14 += part2 - part3 - part4 + part5;
out9 += part6 + part7 + part8 + part9;
out10 += part6 - part7 + part8 - part9;
out15 += part6 - part7 - part8 + part9;
out16 += part6 + part7 - part8 - part9;
out11 += part10 + part11;
out17 += part10 - part11;
out18 += part13 - part12;
out24 += part13 + part12;
out19 += part14 + part15 + part16 + part17;
out20 += part14 - part15 + part16 - part17;
out25 += part14 - part15 - part16 + part17;
out26 += part14 + part15 - part16 - part17;
out21 += part18 + part19 + part20 + part21;
out22 += part18 - part19 + part20 - part21;
out27 += part18 - part19 - part20 + part21;
out28 += part18 + part19 - part20 - part21;
out23 += part22 + part23;
out29 += part22 - part23;
*(dst_addr) = out6;
dst_addr += dst_plane_stride;
*(dst_addr) = out7;
dst_addr += dst_plane_stride;
*(dst_addr) = out8;
dst_addr += dst_plane_stride;
*(dst_addr) = out9;
dst_addr += dst_plane_stride;
*(dst_addr) = out10;
dst_addr += dst_plane_stride;
*(dst_addr) = out11;
dst_addr += dst_plane_stride;
*(dst_addr) = out12;
dst_addr += dst_plane_stride;
*(dst_addr) = out13;
dst_addr += dst_plane_stride;
*(dst_addr) = out14;
dst_addr += dst_plane_stride;
*(dst_addr) = out15;
dst_addr += dst_plane_stride;
*(dst_addr) = out16;
dst_addr += dst_plane_stride;
*(dst_addr) = out17;
dst_addr += dst_plane_stride;
*(dst_addr) = out18;
dst_addr += dst_plane_stride;
*(dst_addr) = out19;
dst_addr += dst_plane_stride;
*(dst_addr) = out20;
dst_addr += dst_plane_stride;
*(dst_addr) = out21;
dst_addr += dst_plane_stride;
*(dst_addr) = out22;
dst_addr += dst_plane_stride;
*(dst_addr) = out23;
dst_addr += dst_plane_stride;
*(dst_addr) = out24;
dst_addr += dst_plane_stride;
*(dst_addr) = out25;
dst_addr += dst_plane_stride;
*(dst_addr) = out26;
dst_addr += dst_plane_stride;
*(dst_addr) = out27;
dst_addr += dst_plane_stride;
*(dst_addr) = out28;
dst_addr += dst_plane_stride;
*(dst_addr) = out29;
dst_addr += dst_plane_stride;
// Row5
VEC_DATA_TYPE(DATA_TYPE, 4)
d50 = vload4(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 2)
d51 = vload2(2, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
// Channels [30, 35]
out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
*(dst_addr) = out0;
dst_addr += dst_plane_stride;
*(dst_addr) = out1;
dst_addr += dst_plane_stride;
*(dst_addr) = out2;
dst_addr += dst_plane_stride;
*(dst_addr) = out3;
dst_addr += dst_plane_stride;
*(dst_addr) = out4;
dst_addr += dst_plane_stride;
*(dst_addr) = out5;
dst_addr += dst_plane_stride;
#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5, the output tile is 4x4/4x1 or 1x4 and the data layout is NCHW
*
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
const int x = get_global_id(0);
const int y = get_global_id(1);
#if defined(SRC_DEPTH)
const int z = get_global_id(2) % SRC_DEPTH;
const int b = get_global_id(2) / SRC_DEPTH;
#else /* defined(SRC_DEPTH) */
const int z = get_global_id(2);
#endif /* defined(SRC_DEPTH) */
// Compute input address
#if defined(SRC_DEPTH)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
#else /* defined(SRC_DEPTH) */
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#endif /* defined(SRC_DEPTH) */
src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
// Load input tile
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 6 * src_stride_y)),
*((__global DATA_TYPE *)(src_addr + 7 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row1 = vload8(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row2 = vload8(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row3 = vload8(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row4 = vload8(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row5 = vload8(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row6 = vload8(0, (__global DATA_TYPE *)(src_addr + 6 * src_stride_y));
const VEC_DATA_TYPE(DATA_TYPE, 8) in_row7 = vload8(0, (__global DATA_TYPE *)(src_addr + 7 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
tmp0 = in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = 0.0f;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25 * in_row4;
tmp0 += -in_row6 + (DATA_TYPE)5.25 * in_row4 - (DATA_TYPE)5.25 * in_row2;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25 * in_row3;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact2 = (DATA_TYPE)0.25 * in_row2 - (DATA_TYPE)1.25 * in_row4 + in_row6;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
comm_fact0 = (DATA_TYPE)2.5 * in_row3;
comm_fact1 = (DATA_TYPE)0.5 * in_row1 - comm_fact0 + (DATA_TYPE)2.0 * in_row5;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
comm_fact1 = (DATA_TYPE)2.0 * in_row1 - comm_fact0 + (DATA_TYPE)0.5 * in_row5;
comm_fact2 = (DATA_TYPE)4.0 * in_row2 - (DATA_TYPE)5.0 * in_row4 + in_row6;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25 * in_row3 - (DATA_TYPE)5.25 * in_row5;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Calculate output rows (reuse comm_fact0 vector)
VEC_DATA_TYPE(DATA_TYPE, 8)
out0;
OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
VEC_DATA_TYPE(DATA_TYPE, 8)
out1, out2, out3, out4, out5, out6, out7;
OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Store values across the channels
#if defined(SRC_DEPTH)
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else /* defined(SRC_DEPTH) */
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#endif /* defined(SRC_DEPTH) */
*((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
*((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
*((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
*((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
*((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
*((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
*((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
*((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
*((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
*((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
*((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
*((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
*((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
*((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
*((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
*((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
*((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
*((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
*((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
*((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
*((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
*((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
*((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
*((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
*((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
*((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
*((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
*((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
*((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
*((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
*((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
*((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
*((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
*((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
*((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
*((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
*((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
*((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
*((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
*((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
*((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
*((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
*((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
*((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
*((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
*((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
*((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
*((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
*((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
*((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
*((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
*((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
*((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
*((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
*((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
*((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
*((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
*((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
*((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
*((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
*((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
*((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
*((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
*((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** This OpenCL kernel computes the input transform when the output tile is 4x4, 4x1 or 1x4, the filter size is 3x3, 3x1 or 1x3 and the data layout is NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
// Index channel
const int x = get_global_id(0);
// Index width
const int y = get_global_id(1);
#if defined(NUM_TILES_Y)
// Index height
const int z = get_global_id(2) % NUM_TILES_Y;
// Index batch size
const int b = get_global_id(2) / NUM_TILES_Y;
#else // defined(NUM_TILES_Y)
// Index height
const int z = get_global_id(2);
#endif // defined(NUM_TILES_Y)
#if defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
#else // defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
#endif // defined(NUM_TILES_Y)
// Origin coordinates for the width (y) and height (z) in the input tensor
int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT;
int4 z_coord0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
int2 z_coord1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;
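// A 4x4 output tile with a 3x3 kernel reads a 6x6 input patch (output tile size + kernel size - 1), hence the six coordinates per axis; the 1D variants only use one of the two axes.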
// Coordinates to use to avoid out-of-bound reads
int4 y_coord_valid0 = clamp(y_coord0, (int4)0, (int4)((int)SRC_DIM_1 - 1));
int2 y_coord_valid1 = clamp(y_coord1, (int2)0, (int2)((int)SRC_DIM_1 - 1));
int4 z_coord_valid0 = clamp(z_coord0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
int2 z_coord_valid1 = clamp(z_coord1, (int2)0, (int2)((int)SRC_DIM_2 - 1));
// Boundary conditions
int4 y_cond0 = y_coord_valid0 == y_coord0;
int2 y_cond1 = y_coord_valid1 == y_coord1;
int4 z_cond0 = z_coord_valid0 == z_coord0;
int2 z_cond1 = z_coord_valid1 == z_coord1;
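// With OpenCL vector comparisons a lane is all-ones where the original coordinate was already in range and 0 where it was clamped;
// the FILL_ZERO_OUT_OF_BOUND_* macros below use these masks to zero the elements that fall inside the padded region.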
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d4, y_cond, z_cond1.s0);
DATA_TYPE k0 = d44;
DATA_TYPE k1 = d44;
DATA_TYPE k2 = d44;
DATA_TYPE k3 = d44;
DATA_TYPE k4 = d44;
DATA_TYPE k5 = (DATA_TYPE)0.0f;
k0 += 4.0f * d40 - 5.0f * d42;
k1 += -4.0f * d41 - 4.0f * d42 + d43;
k2 += 4.0f * d41 - 4.0f * d42 - d43;
k3 += -2.0f * d41 + 2.0f * d43 - d42;
k4 += 2.0f * d41 - 2.0f * d43 - d42;
k5 += 4.0f * d41 - 5.0f * d43 + d45;
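// k0..k5 collect the contribution of input row 4 (d4x); it is added to channels [0, 5] and reused as the initial value of channels [6, 29] below.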
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d0, y_cond, z_cond0.s0);
#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(DATA_TYPE, d0, y_cond0.s0, z_cond);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04;
DATA_TYPE out1 = -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 4.0f * d04;
DATA_TYPE out2 = 16.0f * d01 - 16.0f * d02 - 4.0f * d03 + 4.0f * d04;
DATA_TYPE out3 = -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 4.0f * d04;
DATA_TYPE out4 = 8.0f * d01 - 4.0f * d02 - 8.0f * d03 + 4.0f * d04;
DATA_TYPE out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05;
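// out0..out5: 1D input transform of the six elements loaded above; in the 2D case the contributions of the remaining input rows are accumulated below.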
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d2, y_cond, z_cond0.s2);
out0 += k0;
out1 += k1;
out2 += k2;
out3 += k3;
out4 += k4;
out5 += k5;
DATA_TYPE out6 = k0;
DATA_TYPE out7 = k1;
DATA_TYPE out8 = k2;
DATA_TYPE out9 = k3;
DATA_TYPE out10 = k4;
DATA_TYPE out11 = k5;
DATA_TYPE out12 = k0;
DATA_TYPE out13 = k1;
DATA_TYPE out14 = k2;
DATA_TYPE out15 = k3;
DATA_TYPE out16 = k4;
DATA_TYPE out17 = k5;
DATA_TYPE out18 = k0;
DATA_TYPE out19 = k1;
DATA_TYPE out20 = k2;
DATA_TYPE out21 = k3;
DATA_TYPE out22 = k4;
DATA_TYPE out23 = k5;
DATA_TYPE out24 = k0;
DATA_TYPE out25 = k1;
DATA_TYPE out26 = k2;
DATA_TYPE out27 = k3;
DATA_TYPE out28 = k4;
DATA_TYPE out29 = k5;
// Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
out0 += -20.0f * d20 + 25.0f * d22 - 5.0f * d24;
out1 += 20.0f * d21 + 20.0f * d22 - 5.0f * d23 - 5.0f * d24;
out2 += -20.0f * d21 + 20.0f * d22 + 5.0f * d23 - 5.0f * d24;
out3 += 10.0f * d21 + 5.0f * d22 - 10.0f * d23 - 5.0f * d24;
out4 += -10.0f * d21 + 5.0f * d22 + 10.0f * d23 - 5.0f * d24;
out5 += -20.0f * d21 + 25.0f * d23 - 5.0f * d25;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Compute destination address
#if defined(NUM_TILES_Y)
__global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
#else // defined(NUM_TILES_Y)
__global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
#endif // defined(NUM_TILES_Y)
uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
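// Each transformed value is written to a different Z-plane of the destination; advancing the DATA_TYPE pointer by dst_stride_z / sizeof(DATA_TYPE) elements steps one plane forward.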
*((__global DATA_TYPE *)dst_addr) = out0;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out1;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out2;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out3;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out4;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out5;
dst_addr += dst_plane_stride;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d1, y_cond, z_cond0.s1);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d3, y_cond, z_cond0.s3);
// Compute common parts for the channels between [6, 29]
// Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
// Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
DATA_TYPE part0 = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
DATA_TYPE part1 = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
DATA_TYPE part2 = 16.0f * d22 - 4.0f * d24;
DATA_TYPE part3 = 16.0f * d21 - 4.0f * d23;
DATA_TYPE part4 = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
DATA_TYPE part5 = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
DATA_TYPE part6 = 4.0f * d22 - 4.0f * d24;
DATA_TYPE part7 = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
DATA_TYPE part8 = 4.0f * d12 - 4.0f * d14 - d32 + d34;
DATA_TYPE part9 = 8.0f * d21 - 8.0f * d23;
DATA_TYPE part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
DATA_TYPE part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;
// Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
// Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
DATA_TYPE part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d22 - d24
DATA_TYPE part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
DATA_TYPE part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d21 - d23
DATA_TYPE part18 = part6 * 0.25f; // d22 - d24
DATA_TYPE part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
DATA_TYPE part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d21 - d23)
DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d21 + 5.0f * d23 - d25
DATA_TYPE part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;
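// Channels [6, 29] all build on the shared partial sums above; related channels differ only in the signs with which the parts are combined.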
out6 += part0 - part1;
out12 += part0 + part1;
out7 += part2 + part3 + part4 + part5;
out8 += part2 - part3 + part4 - part5;
out13 += part2 + part3 - part4 - part5;
out14 += part2 - part3 - part4 + part5;
out9 += part6 + part7 + part8 + part9;
out10 += part6 - part7 + part8 - part9;
out15 += part6 - part7 - part8 + part9;
out16 += part6 + part7 - part8 - part9;
out11 += part10 + part11;
out17 += part10 - part11;
out18 += part13 - part12;
out24 += part13 + part12;
out19 += part14 + part15 + part16 + part17;
out20 += part14 - part15 + part16 - part17;
out25 += part14 - part15 - part16 + part17;
out26 += part14 + part15 - part16 - part17;
out21 += part18 + part19 + part20 + part21;
out22 += part18 - part19 + part20 - part21;
out27 += part18 - part19 - part20 + part21;
out28 += part18 + part19 - part20 - part21;
out23 += part22 + part23;
out29 += part22 - part23;
*((__global DATA_TYPE *)dst_addr) = out6;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out7;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out8;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out9;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out10;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out11;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out12;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out13;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out14;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out15;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out16;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out17;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out18;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out19;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out20;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out21;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out22;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out23;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out24;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out25;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out26;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out27;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out28;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out29;
dst_addr += dst_plane_stride;
// Row5
DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d5, y_cond, z_cond1.s1);
// Channels [30, 35]
out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;
*((__global DATA_TYPE *)dst_addr) = out0;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out1;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out2;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out3;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out4;
dst_addr += dst_plane_stride;
*((__global DATA_TYPE *)dst_addr) = out5;
dst_addr += dst_plane_stride;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
/** This OpenCL kernel computes the input transform when the kernel size is 5x5, 5x1 or 1x5, the output tile is 4x4, 4x1 or 1x4 and the data layout is NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
const int x = get_global_id(0);
const int y = get_global_id(1);
#if defined(NUM_TILES_Y)
const int z = get_global_id(2) % NUM_TILES_Y;
const int b = get_global_id(2) / NUM_TILES_Y;
#else // defined(NUM_TILES_Y)
const int z = get_global_id(2);
#endif // defined(NUM_TILES_Y)
// Compute input address
#if defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
#else // defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
#endif // defined(NUM_TILES_Y)
// Origin coordinates for the width (y) and height (z) in the input tensor
int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
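// F(4x4, 5x5) reads an 8x8 input patch per tile (output tile size + kernel size - 1 = 8), hence eight coordinates per axis.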
// Coordinates to use to avoid out-of-bound reads
int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1));
int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1));
// Boundary conditions
int8 y_cond0 = y_coord_valid0 == y_coord0;
int8 z_cond0 = z_coord_valid0 == z_coord0;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = 0.0f;
VEC_DATA_TYPE(DATA_TYPE, 8)
tmp0 = in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond);
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = 0.0f;
VEC_DATA_TYPE(DATA_TYPE, 8)
tmp0 = in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
// Row0
in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
// Row1
in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1);
// Row2
in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2);
// Row3
in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3);
// Row4
in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4);
// Row5
in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5);
// Row6
in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6);
// Row7
in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7);
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;
// Calculate intermediate tensor and reuse common factor vectors
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = in_row0 - in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
comm_fact0 = (DATA_TYPE)2.5f * in_row3;
comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.f * in_row5;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
comm_fact1 = (DATA_TYPE)2.f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
comm_fact2 = (DATA_TYPE)4.f * in_row2 - (DATA_TYPE)5.f * in_row4 + in_row6;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5;
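// tmp0..tmp7 are the row-wise combinations of the input tile; OUTPUT_ROW_4x4_5x5 below applies the same 1D transform along the other dimension to complete the 2D input transform.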
// Calculate output rows (reuse comm_fact0 vector)
VEC_DATA_TYPE(DATA_TYPE, 8)
out0, out1, out2, out3, out4, out5, out6, out7;
OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
// Store values across the channels
#if defined(NUM_TILES_Y)
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else /* NUM_TILES_Y */
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
#endif /* NUM_TILES_Y */
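// The 1D variants store 8 transformed values; the full 2D transform stores the complete 8x8 = 64 values, one per Z-plane.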
*((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
*((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
*((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
*((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
*((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
*((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
*((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
*((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
*((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
*((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
*((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
*((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
*((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
*((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
*((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
*((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
*((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
*((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
*((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
*((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
*((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
*((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
*((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
*((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
*((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
*((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
*((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
*((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
*((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
*((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
*((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
*((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
*((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
*((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
*((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
*((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
*((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
*((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
*((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
*((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
*((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
*((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
*((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
*((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
*((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
*((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
*((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
*((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
*((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
*((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
*((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
*((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
*((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
*((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
*((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
*((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
*((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
*((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
*((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
*((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
*((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
*((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
*((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
*((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
/** This OpenCL kernel computes the input transform when the kernel size is 7x7, 7x1 or 1x7, the output tile is 2x2, 2x1 or 1x2 and the data layout is NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=7).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
* @note If this kernel is used to perform Winograd input transform 7x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note If this kernel is used to perform Winograd input transform 1x7, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
const int x = get_global_id(0);
const int y = get_global_id(1);
#if defined(NUM_TILES_Y)
const int z = get_global_id(2) % NUM_TILES_Y;
const int b = get_global_id(2) / NUM_TILES_Y;
#else /* defined(NUM_TILES_Y) */
const int z = get_global_id(2);
#endif /* defined(NUM_TILES_Y) */
// Compute input address
#if defined(NUM_TILES_Y)
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
#else /* defined(NUM_TILES_Y) */
__global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
#endif /* defined(NUM_TILES_Y) */
// Origin coordinates for the width (y) and height (z) in the input tensor
int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
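// The 2x2 output tile with a 7x7 kernel also reads an 8x8 input patch (2 + 7 - 1 = 8).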
// Coordinates to use to avoid out-of-bound reads
int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1));
int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1));
// Boundary conditions
int8 y_cond0 = y_coord_valid0 == y_coord0;
int8 z_cond0 = z_coord_valid0 == z_coord0;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
VEC_DATA_TYPE(DATA_TYPE, 8)
out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
VEC_DATA_TYPE(DATA_TYPE, 8)
tmp0 = ((VEC_DATA_TYPE(DATA_TYPE, 8)) - 36.0f) * in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond);
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
tmp0 = ((VEC_DATA_TYPE(DATA_TYPE, 8)) - 36.0f) * in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
// Row0
in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
// Row1
in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1);
// Row2
in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2);
// Row3
in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3);
// Row4
in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4);
// Row5
in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5);
// Row6
in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6);
// Row7
in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7);
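// At this point in_row0..in_row7 hold the 8x8 input tile gathered along Y and Z for this
// NHWC work-item. The y_coord_valid0/z_coord_valid0 coordinates computed above are assumed
// to be clamped to the tensor bounds, and FILL_ZERO_OUT_OF_BOUND_8_NHWC_H zeroes the lanes
// flagged by y_cond/z_cond0, so padded (out-of-bound) positions contribute zeros below.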
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = (DATA_TYPE)36.0f * in_row2 - (DATA_TYPE)13.0f * in_row4 + in_row6;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact1 = (DATA_TYPE)36.0f * in_row1 - (DATA_TYPE)13.0f * in_row3 + in_row5;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact2 = (DATA_TYPE)9.0f * in_row2 - (DATA_TYPE)10.0f * in_row4 + in_row6;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact3 = (DATA_TYPE)18.0f * in_row1 - (DATA_TYPE)20.0f * in_row3 + (DATA_TYPE)2.0f * in_row5;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact4 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact5 = (DATA_TYPE)12.0f * in_row1 - (DATA_TYPE)15.0f * in_row3 + (DATA_TYPE)3.0f * in_row5;
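// comm_fact0..comm_fact5 are shared sub-expressions of the 8-point 1D Winograd input
// transform taken along the in_row index; they are combined below into the intermediate
// rows tmp0..tmp7, which effectively hold B^T * d for this 8x8 tile.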
// Calculate intermediate tensors
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = -(DATA_TYPE)36.0f * in_row0 + (DATA_TYPE)49.0f * in_row2 - (DATA_TYPE)14.0f * in_row4 + in_row6;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 - comm_fact1;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 + comm_fact1;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact2 - comm_fact3;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 + comm_fact3;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact4 - comm_fact5;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact4 + comm_fact5;
const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = -(DATA_TYPE)36.0f * in_row1 + (DATA_TYPE)49.0f * in_row3 - (DATA_TYPE)14.0f * in_row5 + in_row7;
VEC_DATA_TYPE(DATA_TYPE, 8)
out0, out1, out2, out3, out4, out5, out6, out7;
OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
OUTPUT_ROW_2x2_7x7(out1, tmp1, comm_fact0);
OUTPUT_ROW_2x2_7x7(out2, tmp2, comm_fact0);
OUTPUT_ROW_2x2_7x7(out3, tmp3, comm_fact0);
OUTPUT_ROW_2x2_7x7(out4, tmp4, comm_fact0);
OUTPUT_ROW_2x2_7x7(out5, tmp5, comm_fact0);
OUTPUT_ROW_2x2_7x7(out6, tmp6, comm_fact0);
OUTPUT_ROW_2x2_7x7(out7, tmp7, comm_fact0);
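// OUTPUT_ROW_2x2_7x7 is assumed to apply the matching 8-tap transform along the other
// dimension of each intermediate row, completing the 2D input transform (B^T * d * B).
// comm_fact0 is passed to every call and appears to be reused as a scratch vector that
// the macro overwrites.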
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Store values across the channels
#if defined(NUM_TILES_Y)
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else /* NUM_TILES_Y */
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
#endif /* NUM_TILES_Y */
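// Destination addressing: x (assumed to be the channel index for this NHWC kernel) selects
// the element along dst X, the flattened tile index (y + z * NUM_TILES_X) selects the row
// along dst Y (plus b * dst_stride_w for the batch in the NUM_TILES_Y build), and each
// transformed value k of the tile is written to a separate Z plane at k * dst_stride_z.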
*((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
*((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
*((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
*((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
*((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
*((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
*((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
*((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
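// out1..out7 are computed and stored only in the full 2D build; the 1D horizontal/vertical
// variants write just the first 8 transformed values above.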
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
*((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
*((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
*((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
*((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
*((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
*((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
*((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
*((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
*((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
*((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
*((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
*((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
*((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
*((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
*((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
*((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
*((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
*((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
*((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
*((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
*((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
*((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
*((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
*((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
*((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
*((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
*((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
*((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
*((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
*((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
*((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
*((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
*((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
*((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
*((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
*((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
*((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
*((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
*((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
*((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
*((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
*((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
*((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
*((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
*((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
*((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
*((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
*((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
*((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
*((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
*((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
*((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
*((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
*((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
*((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
*((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
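// The 1D horizontal kernels below are thin wrappers: each forwards its arguments to the
// corresponding 2D kernel, which is expected to take the 1D path when
// -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL is defined.
// Illustrative build options for the 2x1_3x1 case (example values only):
//   -DDATA_TYPE=float -DNUM_TILES_X=5 -DPAD_LEFT=1 -DPAD_TOP=0
//   -DOUTPUT_TILE_W=2 -DOUTPUT_TILE_H=1 -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL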
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
* @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is multiple of 2
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
* @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
* @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 for data layout NCHW
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
* @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
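// The NHWC horizontal wrappers additionally require -DSRC_DIM_1 and -DSRC_DIM_2 (input
// width and height for NHWC), matching the guard above, and forward to the NHWC 2D kernels
// in the same way as the NCHW wrappers.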
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1 for data layout NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
* @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x1_3x1_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 for data layout NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
* @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_4x1_5x1_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_5x5_stepz1_nhwc(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 7x1 and the output tile is 2x1 for data layout NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=7).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
* @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_2x1_7x1_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_2x2_7x7_stepz1_nhwc(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
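// The 1D vertical kernels below follow the same wrapper pattern: they forward to the 2D
// kernels, which are expected to take the vertical (1xN) path when
// -DWINOGRAD_INPUT_TRANSFORM_VERTICAL is defined.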
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
* @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is multiple of 2
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
* @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 for data layout NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_1x4_1x3_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4 for data layout NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
* @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_1x4_1x5_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_4x4_5x5_stepz1_nhwc(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
/** This OpenCL kernel computes the input transform when the kernel size is 1x7 and the output tile is 1x2 for data layout NHWC
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=7).
* @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
* @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
* @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
* @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
* @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
* @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
*/
__kernel void winograd_input_transform_1x2_1x7_stepz1_nhwc(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
uint src_stride_w,
uint dst_stride_w)
{
winograd_input_transform_2x2_7x7_stepz1_nhwc(src_ptr,
src_stride_x,
src_step_x,
src_stride_y,
src_step_y,
src_stride_z,
src_step_z,
src_offset_first_element_in_bytes,
dst_ptr,
dst_stride_x,
dst_step_x,
dst_stride_y,
dst_step_y,
dst_stride_z,
dst_step_z,
dst_offset_first_element_in_bytes,
src_stride_w,
dst_stride_w);
}
#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
)"