blob: 27c273830e08b523d6d328441f0c0c13e8470f56 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_
#include <cstdint>
#include <limits>
#include <type_traits>
#include "tensorflow/lite/kernels/internal/compatibility.h"
namespace tflite {
namespace cpu_backend_gemm {
// Matrix storage order: column-major or row-major.
enum class Order { kColMajor, kRowMajor };
// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: doing so
// requires complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
template <typename Scalar>
struct MatrixParams {
// Storage layout order. For now we only do plain linear non-strided
// layout. It would be easy to support a stride if needed.
Order order = Order::kColMajor;
// Number of rows of the matrix.
int rows = 0;
// Number of columns of the matrix.
int cols = 0;
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
// When Scalar is floating-point, this must be 0.
Scalar zero_point = 0;
};
// Enumeration of broad categories of Gemm.
//
// The primary reason for this to exist is to allow Gemm to compile
// only uniform-quantized or only per-channel-quantized code paths.
// This is unneeded with ruy as the back-end, as this is only a runtime
// difference in ruy, but with gemmlowp these really are separate code
// paths and templatizing in a QuantizationFlavor is necessary to avoid
// compiling unused gemmlowp code. Indeed, TFLite currently uses
// uint8 with uniform quantization and int8 with per-channel quantization,
// and does not use uint8 with per-channel. We want to avoid compiling
// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
//
// It's possible to drop this in the future if gemmlowp goes away and no
// other then-relevant backend library handles quantized paths in a way that
// requires knowing this at compile-time.
enum class QuantizationFlavor {
// Floating-point Gemm: the accumulators are not multiplied by any
// 'multiplier'.
kFloatingPoint,
// Quantized Gemm using a single multiplier for all accumulators.
kIntegerWithUniformMultiplier,
// Quantized Gemm using a separate multipliers for accumulators of each
// row of the destination matrix. This is what is called 'per-channel'
// in GemmParams. Here we use the more specific 'per-row' terminology
// to allow for the possibility of 'per-column' in the future, and to
// allow for that to be a separate code path in some back-end such as
// gemmlowp.
kIntegerWithPerRowMultiplier
};
// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from that)
// is useful future-proofing. Think of a float16 path using float32 accum.
//
// QuantizationFlavor is passed here even though it's technically not used
// in this class. This is so that we retain the ability in the future to
// specialize this class for quantization flavor, and this allows for
// Gemm to be templatized in quantization_flavor via the GemmParams that it
// takes, allowing for automatic template parameter deduction to take place,
// so that most call sites don't need to specify a QuantizationFlavor
// (only those that need perchannel quantization do).
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor =
std::is_floating_point<AccumScalar>::value
? QuantizationFlavor::kFloatingPoint
: QuantizationFlavor::kIntegerWithUniformMultiplier>
struct GemmParams {
// Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
// of the multiplier by which accumulators are multiplied before being casted
// to the destination type.
AccumScalar multiplier_fixedpoint = 0;
// Only for non-floating-point cases. The exponent part of the aforementioned
// multiplier.
int multiplier_exponent = 0;
// Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
// point to a buffer of as many values as there are rows in the destination
// matrix. Each row of the destination matrix will use the corresponding
// buffer element instead of multiplier_fixedpoint.
const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
// Per-channel variant of multiplier_exponent. If not nullptr, this must
// point to a buffer of as many values as there are rows in the destination
// matrix. Each row of the destination matrix will use the corresponding
// buffer element instead of multiplier_exponent.
//
// Either none or both of multiplier_exponent_perchannel and
// multiplier_fixedpoint_perchannel must be nullptr.
const int* multiplier_exponent_perchannel = nullptr;
// The bias vector data, if not null.
const AccumScalar* bias = nullptr;
// min clamp bound of destination values.
DstScalar clamp_min = std::is_floating_point<DstScalar>::value
? -std::numeric_limits<DstScalar>::infinity()
: std::numeric_limits<DstScalar>::lowest();
// max clamp bound of destination values.
DstScalar clamp_max = std::is_floating_point<DstScalar>::value
? std::numeric_limits<DstScalar>::infinity()
: std::numeric_limits<DstScalar>::max();
};
/* Convenience typedefs */
template <typename DstScalar>
using QuantizedGemmParams = GemmParams<std::int32_t, DstScalar>;
using FloatGemmParams = GemmParams<float, float>;
/* Validation functions */
// Note that this uses TFLITE_DCHECK from kernels/internal/compatibility.h
// and not TF_LITE_ASSERT from op_macros.h. We want this to be explicitly
// debug-build-only assertions so that there's not reason not to
// generously validate, and TF_LITE_ASSERT is actually at the moment
// a release-build assertion. See b/131587258.
// Validates self-consistency of GemmParams.
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
void ValidateGemmParams(
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
// Guard consistency of the quantized multiplier fields.
if (quantization_flavor == QuantizationFlavor::kFloatingPoint) {
TFLITE_DCHECK(!params.multiplier_fixedpoint);
TFLITE_DCHECK(!params.multiplier_exponent);
TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
} else if (quantization_flavor ==
QuantizationFlavor::kIntegerWithUniformMultiplier) {
TFLITE_DCHECK(params.multiplier_fixedpoint);
// Nothing to check about multiplier_exponent
TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
} else if (quantization_flavor ==
QuantizationFlavor::kIntegerWithPerRowMultiplier) {
TFLITE_DCHECK(!params.multiplier_fixedpoint);
TFLITE_DCHECK(!params.multiplier_exponent);
TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
TFLITE_DCHECK(params.multiplier_exponent_perchannel);
}
}
namespace detail {
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
struct ValidateTypes {
// This generic implementation is for quantized flavors.
// kFloatingPoint will be a specialization below.
static_assert(!std::is_floating_point<LhsScalar>::value, "");
static_assert(!std::is_floating_point<RhsScalar>::value, "");
static_assert(!std::is_floating_point<AccumScalar>::value, "");
// No requirement on DstScalar --- we might in the future allow it
// to be floating point even in a quantized Gemm.
};
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
QuantizationFlavor::kFloatingPoint> {
static_assert(std::is_floating_point<LhsScalar>::value, "");
static_assert(std::is_floating_point<RhsScalar>::value, "");
static_assert(std::is_floating_point<AccumScalar>::value, "");
static_assert(std::is_floating_point<DstScalar>::value, "");
};
} // namespace detail
// Validates overall consistency of all the parameters taken by a Gemm call:
// the 3 MatrixParams and the GemmParams.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
void ValidateParams(
const MatrixParams<LhsScalar>& lhs_params,
const MatrixParams<RhsScalar>& rhs_params,
const MatrixParams<DstScalar>& dst_params,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
(void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>();
ValidateGemmParams(params);
// For now, Gemm only supports this particular combination of storage orders.
// Actually the generic ruy path already supports all combinations (with
// various performance penalties). On the other hand, gemmlowp and Eigen
// paths would require more source code and larger binary code to handle
// other combinations (because orders are template parameters in gemmlowp
// and Eigen). Since this is TFLite's own internal Gemm library, there is
// no point in supporting more than what TFlite currently uses, and that
// is for now this single combination.
TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
TFLITE_DCHECK(dst_params.order == Order::kColMajor);
}
} // namespace cpu_backend_gemm
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_