// blob: 42a57009f30294a3e3466c92e5490db534c5ab18 [file] [log] [blame]
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef RUY_RUY_MUL_PARAMS_H_
#define RUY_RUY_MUL_PARAMS_H_
#include <cstdint>
#include <limits>
#include <type_traits>
#include "ruy/check_macros.h"
#include "ruy/size_util.h"
namespace ruy {
// Selects which dimension of the destination matrix plays the role of
// 'channels'. This matters for the MulParams features that are applied
// 'per-channel': the bias vector and the per-channel quantized multipliers.
enum class ChannelDimension : std::int8_t {
  // One channel per row of the destination matrix.
  kRow = 0,
  // One channel per column of the destination matrix.
  kCol = 1
};
namespace detail {
// Forward declaration of the storage backing MulParams. The primary template
// and its specializations are defined further down in this file; MulParams
// only needs the name to be visible at this point.
template <typename AccumScalar, typename DstScalar>
struct MulParamsStorage;
}  // namespace detail
// MulParams describes all about a matrix multiplication that
// isn't encoded in the LHS, RHS and destination matrices. Some of that
// information is encoded as compile-time constants and types (for instance, the
// choice of accumulator type, AccumScalar). Some of that information is encoded
// as runtime values (for instance, the optional bias vector).
//
// Template parameters:
// AccumScalar: Accumulator type. The type of accumulators used to compute the
// dot-products before being ultimately cast to the destination type.
// DstScalar: The destination scalar type.
//
// Constraints on these template parameters (see also the ruy::Mul comment):
// * If DstScalar is floating-point then AccumScalar must also be.
// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover
// in that integral case, there is a mode switch:
// - If DstScalar is std::int32_t then the multiplier_* fields are all
// disabled, and ruy::Mul will just return raw (unscaled) accumulators.
// - If DstScalar is not std::int32_t then the multiplier_* fields are
// enabled, and ruy::Mul will use them to scale internal std::int32_t
// accumulators before casting them to the DstScalar type. The default
// values are such that the effective multiplier is 1 (no scaling).
//
// For the latter case (DstScalar integral and narrower than std::int32_t),
// reference code can be found in the implementation of ruy::ApplyMultiplier.
// If you look there, you'll find warnings like this:
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Warning: this code is not meant to be bit-exact-normative.
// Please refer to the class comment of ruy::MulParams, in mul_params.h.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// The explanation of this warning is that as of early 2021, we still don't know
// whether it is advisable to let this code as-is have normative value, or
// whether that would become advisable after some specific final change.
//
// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
// bit-exactly to this reference, but we also know that x86 could be faster if
// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't
// know that this particular reference code is inherently better than other
// forms that could perform better on these architectures --- in fact, the
// alternative that was proposed in [2] as better performing on ARM Cortex-M
// is also inherently more accurate thanks to rounding only once, but it would
// perform worse on both ARM NEON, and x86.
//
// In fact, if we look at other hardware architectures beyond current Ruy
// targets, namely "hardware accelerators", it becomes clear that there is no
// hope for any form of this to be efficiently implementable simultaneously on
// all current relevant hardware. Indeed, some accelerators prefer to perform
// the multiplication in IEEE float32, others in IEEE float16, others in
// bfloat16, others in 16-bit fixed-point...
//
// See:
// [1] https://github.com/google/ruy/pull/227
// [2] https://github.com/tensorflow/tensorflow/issues/25087
template <typename tAccumScalar, typename tDstScalar>
class MulParams final {
 public:
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;

  // Bias vector. A null pointer means that no bias vector is set.
  const AccumScalar* bias() const { return storage_.bias; }
  void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }

  // Non-floating-point cases only. Fixed-point part (the mantissa) of the
  // multiplier applied to accumulators before they are cast to the destination
  // type. Reads as 0 while per-channel mode is active. Calling the setter
  // switches back to the uniform (non-per-channel) mode.
  AccumScalar multiplier_fixedpoint() const {
    if (storage_.perchannel) {
      return 0;
    }
    return storage_.multiplier_fixedpoint;
  }
  void set_multiplier_fixedpoint(const AccumScalar value) {
    set_perchannel(false);
    storage_.multiplier_fixedpoint = value;
  }

  // Non-floating-point cases only. Exponent part of the multiplier described
  // above. Reads as 0 while per-channel mode is active. Calling the setter
  // switches back to the uniform (non-per-channel) mode.
  int multiplier_exponent() const {
    if (storage_.perchannel) {
      return 0;
    }
    return storage_.multiplier_exponent;
  }
  void set_multiplier_exponent(const int value) {
    set_perchannel(false);
    storage_.multiplier_exponent = value;
  }

  // Per-channel counterpart of multiplier_fixedpoint. Calling the setter
  // switches to per-channel mode: `multiplier_fixedpoint` and
  // `multiplier_exponent` become disabled, and the two `*_perchannel` buffers
  // are used in their place.
  //
  // The buffer holds one value per channel (one per row of the destination
  // matrix under the default ChannelDimension::kRow); each channel uses its
  // own buffer element instead of multiplier_fixedpoint.
  //
  // The getter yields nullptr while per-channel mode is not active.
  const AccumScalar* multiplier_fixedpoint_perchannel() const {
    if (!storage_.perchannel) {
      return nullptr;
    }
    return storage_.multiplier_fixedpoint_perchannel;
  }
  void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) {
    set_perchannel(true);
    storage_.multiplier_fixedpoint_perchannel = ptr;
  }

  // Per-channel counterpart of multiplier_exponent. Everything said about
  // multiplier_fixedpoint_perchannel applies here as well.
  const int* multiplier_exponent_perchannel() const {
    if (!storage_.perchannel) {
      return nullptr;
    }
    return storage_.multiplier_exponent_perchannel;
  }
  void set_multiplier_exponent_perchannel(const int* ptr) {
    set_perchannel(true);
    storage_.multiplier_exponent_perchannel = ptr;
  }

  // Lower clamp bound applied to destination values.
  DstScalar clamp_min() const { return storage_.clamp_min; }
  void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; }

  // Upper clamp bound applied to destination values.
  DstScalar clamp_max() const { return storage_.clamp_max; }
  void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; }

  // Which dimension is treated as 'channels' by the per-channel features
  // (bias addition and per-channel quantization multipliers).
  ChannelDimension channel_dimension() const {
    return storage_.channel_dimension;
  }
  void set_channel_dimension(ChannelDimension value) {
    storage_.channel_dimension = value;
  }

  // Upward rounding of the allocated capacity of per-channel buffers (bias
  // vector, per-channel quantization multipliers). Expressed in matrix
  // entries, not in bytes. Must be a power of two.
  //
  // With the default value of 1 there is no rounding: the buffers need no more
  // capacity than the size of the matching matrix dimension, i.e. the number
  // of destination rows when `channel_dimension()` is kRow, or of destination
  // columns when it is kCol.
  //
  // Larger values tell the implementation that it may read these buffers a
  // little past that boundary, which SIMD-optimized kernels benefit from. In
  // practice, if this value is smaller than what the kernel needs, ruy has to
  // reallocate and copy the per-channel buffers internally; if it is large
  // enough, that reallocation and copy is skipped.
  //
  // Whenever a value greater than 1 is given, the tail of the buffer (beyond
  // the entries that correspond to actual channels) must be zero-initialized.
  //
  // As of 2020, values up to 16 may be useful on some CPU architectures
  // (matching the widest kernels used on any CPU architecture).
  int perchannel_buffers_capacity_rounding() const {
    // Only the base-2 logarithm is stored; rebuild the actual value from it.
    const int rounding_log2 = storage_.perchannel_buffers_capacity_rounding_log2;
    return 1 << rounding_log2;
  }
  void set_perchannel_buffers_capacity_rounding(int value) {
    // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two.
    const auto rounding_log2 = pot_log2(value);
    storage_.perchannel_buffers_capacity_rounding_log2 = rounding_log2;
  }

 private:
  // Switches between the uniform and the per-channel multiplier modes. Does
  // nothing when the requested mode is already the current one. On an actual
  // switch, debug-checks that the fields belonging to the mode being entered's
  // counterpart are still at their zero/null values, then records the new mode.
  void set_perchannel(bool perchannel) {
    const bool mode_changes = storage_.perchannel != perchannel;
    if (!mode_changes) {
      return;
    }
    if (!perchannel) {
      // Leaving per-channel mode: no per-channel buffer may have been set.
      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
      RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
    } else {
      // Entering per-channel mode: the uniform multiplier must be untouched.
      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
      RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
    }
    storage_.perchannel = perchannel;
  }

  // All state lives here; the storage type depends on the scalar types.
  detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
};
namespace detail {
// Primary template: storage for the floating-point case.
template <typename AccumScalar, typename DstScalar>
struct MulParamsStorage final {
  // Both scalar types must be floating-point, and the destination type must
  // not be wider than the accumulator type.
  static_assert(std::is_floating_point<AccumScalar>::value, "");
  static_assert(std::is_floating_point<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");

  // No bias vector by default.
  const AccumScalar* bias{nullptr};
  // Default clamp bounds are -infinity / +infinity.
  DstScalar clamp_min{-std::numeric_limits<DstScalar>::infinity()};
  DstScalar clamp_max{std::numeric_limits<DstScalar>::infinity()};
  ChannelDimension channel_dimension{ChannelDimension::kRow};
  // Base-2 logarithm of the per-channel buffers capacity rounding; 0 encodes
  // a rounding of 1.
  std::int8_t perchannel_buffers_capacity_rounding_log2{0};

  // The fields below are disabled in this case. They are kept as
  // `static constexpr` members so that generic code reading them can still be
  // written.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel{
      nullptr};
  static constexpr const int* multiplier_exponent_perchannel{nullptr};
  static constexpr AccumScalar multiplier_fixedpoint{0};
  static constexpr int multiplier_exponent{0};
  static constexpr bool perchannel{false};
};
// Specialization for the integer-quantized case, where int32 accumulators are
// down-quantized to a narrower destination scalar type.
template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
  using AccumScalar = std::int32_t;
  // The destination type must be integral and strictly narrower than the
  // accumulator type.
  static_assert(std::is_integral<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");

  // No bias vector by default.
  const AccumScalar* bias{nullptr};

  // The per-channel pointer and the uniform value of each multiplier part used
  // to share a union; they were temporarily flattened into separate members to
  // debug a crash.
  const AccumScalar* multiplier_fixedpoint_perchannel{nullptr};
  // Uniform fixed-point multiplier, 0 by default. MulParams::set_perchannel
  // debug-checks that this is still 0 when switching to per-channel mode.
  // NOTE(review): an earlier version of this comment described the default as
  // the highest representable fixed-point value (an approximation of a
  // multiplication by 1); the initializer is 0 -- confirm which default is
  // intended before relying on it.
  AccumScalar multiplier_fixedpoint{0};

  const int* multiplier_exponent_perchannel{nullptr};
  // Uniform multiplier exponent, 0 by default. The remarks on the default of
  // multiplier_fixedpoint apply here too.
  int multiplier_exponent{0};

  // Default clamp bounds span the whole range of the destination type.
  DstScalar clamp_min{std::numeric_limits<DstScalar>::lowest()};
  DstScalar clamp_max{std::numeric_limits<DstScalar>::max()};
  ChannelDimension channel_dimension{ChannelDimension::kRow};
  // Whether the per-channel multiplier buffers are in use rather than the
  // uniform multiplier values.
  bool perchannel{false};
  // Base-2 logarithm of the per-channel buffers capacity rounding; 0 encodes
  // a rounding of 1.
  std::int8_t perchannel_buffers_capacity_rounding_log2{0};
};
// Specialization for the integer case that outputs raw int32 accumulators,
// i.e. without down-quantization to a narrower destination scalar type.
// Clamping of destination values is not available in this case.
template <>
struct MulParamsStorage<std::int32_t, std::int32_t> final {
  using AccumScalar = std::int32_t;
  using DstScalar = std::int32_t;

  // No bias vector by default.
  const AccumScalar* bias{nullptr};
  ChannelDimension channel_dimension{ChannelDimension::kRow};
  // Base-2 logarithm of the per-channel buffers capacity rounding; 0 encodes
  // a rounding of 1.
  std::int8_t perchannel_buffers_capacity_rounding_log2{0};

  // The fields below are disabled in this case. They are kept as
  // `static constexpr` members so that generic code reading them can still be
  // written.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel{
      nullptr};
  static constexpr const int* multiplier_exponent_perchannel{nullptr};
  static constexpr AccumScalar multiplier_fixedpoint{0};
  static constexpr int multiplier_exponent{0};
  static constexpr DstScalar clamp_min{
      std::numeric_limits<DstScalar>::lowest()};
  static constexpr DstScalar clamp_max{std::numeric_limits<DstScalar>::max()};
  static constexpr bool perchannel{false};
};
} // namespace detail
} // namespace ruy
#endif // RUY_RUY_MUL_PARAMS_H_