Support CPU Apply in ATen and implement standard_gamma using it (#4161)
* Support CPU Apply directly in ATen and implement standard_gamma using it.
Main changes in this PR:
1) Added a TH_APPLY-style templatized function for CPU apply calls (currently only 2 and 3 tensor argument
versions are supported, but more are easy to add). In fact, this is basically identical to TH_APPLY, except
it uses ATen functions and the API is a template instead of a macro. The template takes an operation that
is performed on the data (and an indicator to signal early termination); i.e. you don't need to know that
x_data is a pointer to the current data location of x.
2) Refactors the ATen dispatch code to easily generate dispatch code for different subsets of the scalar types.
This is in preference to the template_scalar path, which requires valid specialization of each scalar type. Valid
specializations are particularly annoying with CUDA because you most likely can't put the specializations
in a header so need to write some sort of for-all-scalar-type macro to get the correct specializations.
Currently, we only generate dispatch_all (all scalar types, the equivalent existed already), and
dispatch_cpu_floating_types (which is used by standard_gamma).
3) Implements standard_gamma using the above changes (this is an arbitrary choice, it was the latest
apply macro to be committed). The forward is bound via Declarations.yaml,
the backward via the Apply template, and then they are hooked together in derivatives.yaml. This eliminates
needing to change TH at all going forward, which means one can write idiomatic C++ instead of the TH-style macros
(e.g. TH_MATH_NAME).
* Generate Dispatch code with nicer spacing.
* Small cleanups.
* Fix typo.
* Add TODOs for changing macros, remove dead code.
* Use a lambda function.
* Get rid of early exit.
* Rename Scalar,ScalarType template parameters to CScalar.
* Reorder _standard_gamma_grad parameters.
* Add comments explaining calling convention.
* Don't generate Dispatch.h anymore.
* Get rid of backend specific checks in dispatch.
* Fix empty/scalar check.
diff --git a/aten/src/ATen/CPUApplyUtils.h b/aten/src/ATen/CPUApplyUtils.h
new file mode 100644
index 0000000..ce35daf
--- /dev/null
+++ b/aten/src/ATen/CPUApplyUtils.h
@@ -0,0 +1,258 @@
+#pragma once
+
+#include <sstream>
+
+namespace at {
+
+/*
+ * The basic strategy for apply is as follows:
+ *
+ * 1. Starting with the outermost index, loop until we reach a dimension where the
+ * data is no longer contiguous, i.e. the stride at that dimension is not equal to
+ * the size of the tensor defined by the outer dimensions. Let's call this outer
+ * (contiguous) tensor A. Note that if the Tensor is contiguous, then A is equal
+ * to the entire Tensor. Let's call the inner tensor B.
+ *
+ * 2. We loop through the indices in B, starting at its outermost dimension. For
+ * example, if B is a 2x2 matrix, then we do:
+ *
+ * B[0][0]
+ * B[0][1]
+ * B[1][0]
+ * B[1][1]
+ *
+ * We set the offset into the underlying storage as (storageOffset + stride_B * index_B),
+ * i.e. basically we compute the offset into the storage as we would normally for a
+ * Tensor. But because we are guaranteed the subsequent data is contiguous in memory, we
+ * can simply loop for sizeof(A) iterations and perform the operation, without having to
+ * follow the order described by the strides of A.
+ *
+ * 3. As an optimization, we merge dimensions of A that are contiguous in memory. For
+ * example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor, then the first two
+ * dimensions can be merged for the purposes of APPLY, reducing the number of nested
+ * loops.
+ */
+
+static inline void check_correct_backend(const Tensor &t, unsigned int pos) {
+ if (t.type().backend() != Backend::CPU) {
+ runtime_error("Expected tensor at position %d to have CPU Backend, but has %s Backend",
+ pos, toString(t.type().backend()));
+ }
+}
+
+static inline void check_correct_backend(const Tensor& t1, const Tensor &t2) {
+ check_correct_backend(t1, 1);
+ check_correct_backend(t2, 2);
+}
+
+static inline void check_correct_backend(const Tensor& t1, const Tensor &t2, const Tensor &t3) {
+ check_correct_backend(t1, 1);
+ check_correct_backend(t2, 2);
+ check_correct_backend(t3, 3);
+}
+
+// TODO: turn this macro into a proper template
+#define __ATH_TENSOR_APPLYX_PREAMBLE(TYPE, ATENSOR, DIM, ALLOW_CONTIGUOUS) \
+ TYPE *ATENSOR##_data = NULL; \
+ int64_t *ATENSOR##_counter = NULL, *ATENSOR##_sizes = NULL, *ATENSOR##_strides = NULL, *ATENSOR##_dimOffset = NULL; \
+ int64_t ATENSOR##_stride = 0, ATENSOR##_size = 0, ATENSOR##_dim = 0, ATENSOR##_i; \
+ int ATENSOR##_contiguous = ALLOW_CONTIGUOUS && DIM < 0; \
+\
+ if(ATENSOR.sizes().equals({0})) \
+ TH_TENSOR_APPLY_hasFinished = true; \
+ else \
+ { \
+ ATENSOR##_data = ATENSOR.data<TYPE>(); \
+ ATENSOR##_size = 1; \
+ ATENSOR##_stride = 1; \
+ for(ATENSOR##_i = ATENSOR.dim() - 1; ATENSOR##_i >= 0; ATENSOR##_i--) { \
+ if(ATENSOR.sizes()[ATENSOR##_i] != 1) { \
+ if(ATENSOR.strides()[ATENSOR##_i] == ATENSOR##_size && ATENSOR##_i != DIM) \
+ ATENSOR##_size *= ATENSOR.sizes()[ATENSOR##_i]; \
+ else{ \
+ ATENSOR##_contiguous = 0; \
+ break; \
+ } \
+ } \
+ } \
+ if (!ATENSOR##_contiguous) { \
+ /* Find the dimension of contiguous sections */ \
+ ATENSOR##_dim = 1; \
+ for(ATENSOR##_i = ATENSOR.dim() - 2; ATENSOR##_i >= 0; ATENSOR##_i--) \
+ { \
+ if(ATENSOR.strides()[ATENSOR##_i] != ATENSOR.strides()[ATENSOR##_i+1] * ATENSOR.sizes()[ATENSOR##_i+1] || ATENSOR##_i == DIM || ATENSOR##_i+1 == DIM) \
+ ATENSOR##_dim++; \
+ } \
+ /* Allocate an array of 3*dim elements, where dim is the number of contiguous sections */ \
+ ATENSOR##_counter = new int64_t[3*ATENSOR##_dim]; \
+ ATENSOR##_sizes = ATENSOR##_counter + ATENSOR##_dim; \
+ ATENSOR##_strides = ATENSOR##_counter + 2*ATENSOR##_dim; \
+ TH_TENSOR_dim_index = ATENSOR##_dim-1; \
+ ATENSOR##_dimOffset = (DIM == ATENSOR.dim()-1) ? &ATENSOR##_i : &ATENSOR##_counter[DIM]; \
+ ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR.dim()-1]; \
+ ATENSOR##_strides[TH_TENSOR_dim_index] = ATENSOR.strides()[ATENSOR.dim()-1]; \
+ /* ATENSOR##_counter tracks where we are in the storage. The offset into the */ \
+ /* storage is given by storage_offset + (i * j), where i is the stride */ \
+ /* vector and j is tensor_counter vector. This sets the starting position for the loop. */ \
+ for(ATENSOR##_i = ATENSOR##_dim-1; ATENSOR##_i >= 0; --ATENSOR##_i) { \
+ ATENSOR##_counter[ATENSOR##_i] = 0; \
+ } \
+ for(ATENSOR##_i = ATENSOR.dim()-2; ATENSOR##_i >= 0; --ATENSOR##_i) { \
+ if (ATENSOR.strides()[ATENSOR##_i] == ATENSOR.strides()[ATENSOR##_i+1] * ATENSOR.sizes()[ATENSOR##_i+1] && ATENSOR##_i != DIM && ATENSOR##_i+1 != DIM) { \
+ ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR##_i] * ATENSOR##_sizes[TH_TENSOR_dim_index]; \
+ if (DIM != ATENSOR.dim()-1 && ATENSOR##_i < DIM) \
+ ATENSOR##_dimOffset--; \
+ } else { \
+ --TH_TENSOR_dim_index; \
+ ATENSOR##_sizes[TH_TENSOR_dim_index] = ATENSOR.sizes()[ATENSOR##_i]; \
+ ATENSOR##_strides[TH_TENSOR_dim_index] = ATENSOR.strides()[ATENSOR##_i]; \
+ } \
+ } \
+ /* Size of the inner most section */ \
+ ATENSOR##_size = ATENSOR##_sizes[ATENSOR##_dim-1]; \
+ /* Stride of the inner most section */ \
+ ATENSOR##_stride = ATENSOR##_strides[ATENSOR##_dim-1]; \
+ } \
+ } \
+ ATENSOR##_i = 0;
+
+// TODO: turn this macro into a proper template
+#define __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(ATENSOR, ALWAYS_UPDATE) \
+ if(ATENSOR##_i == ATENSOR##_size || ALWAYS_UPDATE) \
+ { \
+ if(ATENSOR##_contiguous) \
+ break; \
+\
+ if(ATENSOR##_dim == 1) \
+ break; \
+\
+ /* Reset pointer to beginning of loop */ \
+ ATENSOR##_data -= ATENSOR##_size*ATENSOR##_stride; \
+ for(ATENSOR##_i = ATENSOR##_dim-2; ATENSOR##_i >= 0; ATENSOR##_i--) \
+ { \
+ ATENSOR##_counter[ATENSOR##_i]++; \
+ /* Jump ahread by the stride of this dimension */ \
+ ATENSOR##_data += ATENSOR##_strides[ATENSOR##_i]; \
+\
+ if(ATENSOR##_counter[ATENSOR##_i] == ATENSOR##_sizes[ATENSOR##_i]) \
+ { \
+ if(ATENSOR##_i == 0) \
+ { \
+ TH_TENSOR_APPLY_hasFinished = true; \
+ break; \
+ } \
+ else \
+ { \
+ /* Reset the pointer to the beginning of the chunk defined by this dimension */ \
+ ATENSOR##_data -= ATENSOR##_counter[ATENSOR##_i]*ATENSOR##_strides[ATENSOR##_i]; \
+ ATENSOR##_counter[ATENSOR##_i] = 0; \
+ } \
+ } \
+ else \
+ break; \
+ } \
+ ATENSOR##_i = 0; \
+ }
+
+template <typename CScalar, typename Op>
+void CPU_tensor_apply2_dim(Tensor& tensor1, Tensor& tensor2, int64_t dim, Op op) {
+ check_correct_backend(tensor1, tensor2);
+ bool TH_TENSOR_APPLY_hasFinished = false;
+ int64_t TH_TENSOR_dim_index = 0;
+ __ATH_TENSOR_APPLYX_PREAMBLE(CScalar, tensor1, dim, 1)
+ __ATH_TENSOR_APPLYX_PREAMBLE(CScalar, tensor2, dim, 1)
+ auto t1_numel = tensor1.numel();
+ auto t2_numel = tensor2.numel();
+ if(t1_numel != t2_numel) {
+ std::ostringstream oss;
+ oss << "inconsistent tensor size, expected " << tensor1.sizes() << " and " << tensor2.sizes()
+ << " to have the same number of elements, but got " << t1_numel << " and " << t2_numel << " elements respectively";
+ throw std::runtime_error(oss.str());
+ }
+ while(!TH_TENSOR_APPLY_hasFinished)
+ {
+ /* Loop through the inner most region of the Tensor */
+ for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size; tensor1_i++, tensor2_i++, tensor1_data += tensor1_stride, tensor2_data += tensor2_stride)
+ {
+ op(*tensor1_data, *tensor2_data);
+ }
+ __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
+ __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
+ }
+ if(tensor1_counter != NULL)
+ delete [] tensor1_counter;
+ if(tensor2_counter != NULL)
+ delete [] tensor2_counter;
+}
+
+/*
+ Apply a pointwise operator to two tensors.
+
+ The calling convention for op is a function/functor that takes takes two references to
+ type CScalar; at least one of these references should be non-const in order to write the output.
+ For example, to compute a = b^2, op would be of the form:
+ [](CScalar &a_val, const CScalar &b_val) { a_val = b_val * b_val; };
+*/
+template<typename CScalar, typename Op>
+void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, Op op) {
+ CPU_tensor_apply2_dim<CScalar, Op>(tensor1, tensor2, -1, op);
+}
+
+template <typename CScalar, typename Op>
+void CPU_tensor_apply3_dim(Tensor &tensor1, Tensor& tensor2, Tensor& tensor3, int64_t dim, Op op) {
+ check_correct_backend(tensor1, tensor2, tensor3);
+ bool TH_TENSOR_APPLY_hasFinished = false;
+ int64_t TH_TENSOR_dim_index = 0;
+ __ATH_TENSOR_APPLYX_PREAMBLE(CScalar, tensor1, dim, 1)
+ __ATH_TENSOR_APPLYX_PREAMBLE(CScalar, tensor2, dim, 1)
+ __ATH_TENSOR_APPLYX_PREAMBLE(CScalar, tensor3, dim, 1)
+
+ int elements_equal = 1;
+ auto t1_numel = tensor1.numel();
+ auto t2_numel = tensor2.numel();
+ auto t3_numel = tensor3.numel();
+ if(t1_numel!= t2_numel) {
+ elements_equal = 0;
+ } else if(t1_numel != t3_numel) {
+ elements_equal = 0;
+ }
+ if (elements_equal == 0) {
+ std::ostringstream oss;
+ oss << "inconsistent tensor size, expected " << tensor1.sizes() << ", " << tensor2.sizes() << ", and " << tensor3.sizes()
+ << " to have the same number of elements, but got " << t1_numel << ", " << t2_numel << ", and " << t3_numel << " elements respectively";
+ throw std::runtime_error(oss.str());
+ }
+
+ while(!TH_TENSOR_APPLY_hasFinished)
+ {
+ /* Loop through the inner most region of the Tensor */
+ for(; tensor1_i < tensor1_size && tensor2_i < tensor2_size && tensor3_i < tensor3_size; tensor1_i++, tensor2_i++, tensor3_i++, tensor1_data += tensor1_stride, tensor2_data += tensor2_stride, tensor3_data += tensor3_stride)
+ {
+ op(*tensor1_data, *tensor2_data, *tensor3_data);
+ }
+ __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor1, 0)
+ __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor2, 0)
+ __ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor3, 0)
+ }
+ if(tensor1_counter != NULL)
+ delete [] tensor1_counter;
+ if(tensor2_counter != NULL)
+ delete [] tensor2_counter;
+ if(tensor3_counter != NULL)
+ delete [] tensor3_counter;
+}
+
+/*
+ Apply a pointwise operator to three tensors.
+
+ The calling convention for op is a function/functor that takes takes three references to
+ type CScalar; at least one of these references should be non-const in order to write the output.
+ For example, to compute a = b + c, op would be of the form:
+ [](CScalar &a_val, const CScalar &b_val, const CScalar &c_val) { a_val = b_val + c_val; };
+*/
+template<typename CScalar, typename Op>
+void CPU_tensor_apply3(Tensor tensor1, Tensor tensor2, Tensor tensor3, Op op) {
+ CPU_tensor_apply3_dim<CScalar, Op>(tensor1, tensor2, tensor3, -1, op);
+}
+
+}
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 6a73f5e..76740cd 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -3869,6 +3869,26 @@
- THTensor* self
]]
[[
+ name: _standard_gamma
+ types:
+ - floating_point
+ backends:
+ - CPU
+ return: argument 0
+ variants:
+ - method
+ - function
+ options:
+ - cname: standard_gamma
+ arguments:
+ - arg: THTensor* output
+ output: True
+ - arg: THGenerator* generator
+ default: THPGenerator_TH_CData(THPDefaultGenerator)
+ kwarg_only: True
+ - THTensor* self
+]]
+[[
name: tensor
return: THTensor*
cpu_half: True
diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
new file mode 100644
index 0000000..a56c50a
--- /dev/null
+++ b/aten/src/ATen/Dispatch.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <utility>
+
+namespace at {
+
+template<template <typename> class F, typename ... Args>
+auto dispatch_all(const Type& the_type, const char *name, Args&&... args)
+ -> decltype(F<double>::apply(std::forward<Args>(args)...)) {
+
+ switch(the_type.scalarType()) {
+ case ScalarType::Byte:
+ return F<uint8_t>::apply(std::forward<Args>(args)...);
+ case ScalarType::Char:
+ return F<int8_t>::apply(std::forward<Args>(args)...);
+ case ScalarType::Double:
+ return F<double>::apply(std::forward<Args>(args)...);
+ case ScalarType::Float:
+ return F<float>::apply(std::forward<Args>(args)...);
+ case ScalarType::Int:
+ return F<int>::apply(std::forward<Args>(args)...);
+ case ScalarType::Long:
+ return F<int64_t>::apply(std::forward<Args>(args)...);
+ case ScalarType::Short:
+ return F<int16_t>::apply(std::forward<Args>(args)...);
+ case ScalarType::Half:
+ return F<Half>::apply(std::forward<Args>(args)...);
+ default:
+ runtime_error("%s not implemented for '%s'", name, the_type.toString());
+ }
+}
+template<template <typename> class F, typename ... Args>
+auto dispatch_floating_types(const Type& the_type, const char *name, Args&&... args)
+ -> decltype(F<double>::apply(std::forward<Args>(args)...)) {
+ switch(the_type.scalarType()) {
+ case ScalarType::Double:
+ return F<double>::apply(std::forward<Args>(args)...);
+ case ScalarType::Float:
+ return F<float>::apply(std::forward<Args>(args)...);
+ default:
+ runtime_error("%s not implemented for '%s'", name, the_type.toString());
+ }
+}
+
+
+}
diff --git a/aten/src/ATen/dispatch_macros.py b/aten/src/ATen/dispatch_macros.py
deleted file mode 100644
index 2bb9061..0000000
--- a/aten/src/ATen/dispatch_macros.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from code_template import CodeTemplate
-
-CASE_TEMPLATE = CodeTemplate("""\
-case ${TypeID}:
- return F<${ScalarType}>::${Backend}(the_type,std::forward<Args>(args)...);
-""")
-
-MACRO_TEMPLATE = CodeTemplate("""\
-#pragma once
-
-namespace at {
-
-template<template <typename> class F, typename ... Args>
-auto dispatch(const Type & the_type, Args&&... args)
- -> decltype(F<double>::CPU(the_type,std::forward<Args>(args)...)) {
- switch(the_type.ID()) {
- ${cases}
- default:
- runtime_error("dispatch() not implemented for '%s'",the_type.toString());
- }
-}
-
-}
-""")
-
-
-def create_dispatch(all_types):
- cases = []
- for typ in all_types:
- if typ['Density'] != 'Sparse':
- cases.append(CASE_TEMPLATE.substitute(typ))
- return MACRO_TEMPLATE.substitute(cases=cases)
-
-
-def create(all_types):
- return create_dispatch(all_types)
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 36e1aed..962b3cf 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -7,7 +7,6 @@
import native_parse
import preprocess_declarations
import function_wrapper
-import dispatch_macros
import copy_wrapper
from code_template import CodeTemplate
@@ -119,15 +118,16 @@
densities = ['Dense', 'Sparse']
+# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
scalar_types = [
- ('Byte', 'uint8_t', 'Long', 'uint8_t'),
- ('Char', 'int8_t', 'Long', 'int8_t'),
- ('Double', 'double', 'Double', 'double'),
- ('Float', 'float', 'Double', 'float'),
- ('Int', 'int', 'Long', 'int32_t'),
- ('Long', 'int64_t', 'Long', 'int64_t'),
- ('Short', 'int16_t', 'Long', 'int16_t'),
- ('Half', 'Half', 'Double', 'THHalf'),
+ ('Byte', 'uint8_t', 'Long', 'uint8_t', False),
+ ('Char', 'int8_t', 'Long', 'int8_t', False),
+ ('Double', 'double', 'Double', 'double', True),
+ ('Float', 'float', 'Double', 'float', True),
+ ('Int', 'int', 'Long', 'int32_t', False),
+ ('Long', 'int64_t', 'Long', 'int64_t', False),
+ ('Short', 'int16_t', 'Long', 'int16_t', False),
+ ('Half', 'Half', 'Double', 'THHalf', True),
]
# shared environment for non-derived base classes Type.h Tensor.h Storage.h
@@ -179,7 +179,7 @@
def generate_storage_type_and_tensor(backend, density, scalar_type, declarations):
- scalar_name, c_type, accreal, th_scalar_type = scalar_type
+ scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type
env = {}
density_tag = 'Sparse' if density == 'Sparse' else ''
th_density_tag = 'S' if density == 'Sparse' else ''
@@ -188,6 +188,8 @@
env['ScalarType'] = c_type
env['THScalarType'] = th_scalar_type
env['AccScalarName'] = accreal
+ env['isFloatingType'] = is_floating_type
+ env['isIntegralType'] = not is_floating_type
env['Storage'] = "{}{}Storage".format(backend, scalar_name)
env['Type'] = "{}{}{}Type".format(density_tag, backend, scalar_name)
env['Tensor'] = "{}{}{}Tensor".format(density_tag, backend, scalar_name)
@@ -301,7 +303,7 @@
def declare_outputs():
files = ['Declarations.yaml', 'Type.h', 'Type.cpp', 'Tensor.h',
'TensorMethods.h', 'Functions.h',
- 'Dispatch.h', 'Copy.cpp', 'NativeFunctions.h']
+ 'Copy.cpp', 'NativeFunctions.h']
for f in files:
file_manager.will_write(f)
for fname in sorted(generators.keys()):
@@ -351,7 +353,7 @@
file_manager.write('Tensor.h', TENSOR_H.substitute(top_env))
file_manager.write('TensorMethods.h', TENSOR_METHODS_H.substitute(top_env))
file_manager.write('Functions.h', FUNCTIONS_H.substitute(top_env))
- file_manager.write('Dispatch.h', dispatch_macros.create(all_types))
+
file_manager.write('Copy.cpp', copy_wrapper.create(all_types))
file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H.substitute(top_env))
diff --git a/aten/src/ATen/native/NativeFunctions.cpp b/aten/src/ATen/native/NativeFunctions.cpp
index 7535cdb..c86a8d3 100644
--- a/aten/src/ATen/native/NativeFunctions.cpp
+++ b/aten/src/ATen/native/NativeFunctions.cpp
@@ -1,7 +1,9 @@
#include "ATen/ATen.h"
+#include "ATen/CPUApplyUtils.h"
+#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
-#include "ATen/PinnedMemoryAllocator.h"
#include "ATen/NativeFunctions.h"
+#include "ATen/PinnedMemoryAllocator.h"
#include "ATen/WrapDimUtils.h"
#include <functional>
#include <numeric>
@@ -595,5 +597,75 @@
throw std::runtime_error("not implemented");
}
+
+
+// TODO Replace this with more accurate digamma().
+template <typename CScalar>
+static inline CScalar digamma_one(CScalar x) {
+ const CScalar eps = x * 1e-2;
+ return (std::lgamma(x + eps) - std::lgamma(x - eps)) / (eps + eps);
+}
+
+/** Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
+ for random number x drawn from a standard Gamma distribution Gamma(alpha).
+*/
+template <typename CScalar>
+static inline CScalar standard_gamma_grad_one(CScalar alpha, CScalar x) {
+ // Use an asymptotic approximation for small x.
+ if (x < 0.2f) {
+ const CScalar a0 = 1 / alpha;
+ const CScalar a1 = 1 / (alpha + 1);
+ const CScalar a2 = 1 / (alpha + 2);
+ const CScalar pow_x_alpha = std::pow(x, alpha);
+ const CScalar gamma_pdf = std::pow(x, alpha - 1) * std::exp(-x);
+ const CScalar gamma_cdf = pow_x_alpha * (a0 - x*a1 + 0.5f*x*x*a2);
+ const CScalar gamma_cdf_alpha = (std::log(x) - digamma_one(alpha)) * gamma_cdf
+ - pow_x_alpha * (a0*a0 - x*a1*a1 + 0.5f*x*x*a2*a2);
+ const CScalar result = -gamma_cdf_alpha / gamma_pdf;
+ return std::isnan(result) ? 0 : result;
+ }
+
+ // Use an asymptotic approximation for large alpha.
+ if (alpha > 50.0f) {
+ return std::sqrt(x / alpha);
+ }
+
+ // Use a bivariate rational approximation to the reparameterized gradient.
+ const CScalar u = std::log(x / alpha);
+ const CScalar v = std::log(alpha);
+ static const CScalar coef_uv[3][8] = {
+ {0.16028008, -0.088064309, 0.019630876, -0.0016920282,
+ 1.0, 0.36659853, 0.10843863, 0.0066895454},
+ {0.521894, 0.16095838, 0.06237597, 0.0023884253,
+ 0.083457714, 0.0073297628, -0.0059299053, -0.00093720389},
+ {-0.0031143957, -0.012143877, -0.0057656484, -0.00064847254,
+ 0.0087262576, -0.00022820524, 1.8871047e-05, 9.6307964e-06},
+ };
+ CScalar coef_v[8];
+ for (int i = 0; i < 8; ++ i) {
+ coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
+ }
+ const CScalar p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
+ const CScalar q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
+ return std::exp(p / q);
+}
+
+template <typename CScalar>
+struct StandardGammaGradOp {
+ static void apply(Tensor& ret, const Tensor& self, const Tensor& output) {
+ CPU_tensor_apply3<CScalar>(ret, self, output,
+ [](CScalar& ret_val, const CScalar& self_val, const CScalar &output_val) {
+ ret_val = standard_gamma_grad_one(self_val, output_val);
+ }
+ );
+ }
+};
+
+Tensor _standard_gamma_grad(const Tensor& self, const Tensor& output) {
+ Tensor ret = self.type().tensor(self.sizes());
+ dispatch_floating_types<StandardGammaGradOp>(self.type(), "_standard_gamma_grad", ret, self, output);
+ return ret;
+}
+
}
}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index cb0e4d5..925e2d8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -105,6 +105,8 @@
- func: is_signed(Tensor self) -> bool
template_scalar: True
+- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
+
- func: matmul(Tensor self, Tensor other) -> Tensor
- func: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) -> (Tensor, Tensor)
diff --git a/aten/src/ATen/test/native_test.cpp b/aten/src/ATen/test/native_test.cpp
index 4787df3..614b767 100644
--- a/aten/src/ATen/test/native_test.cpp
+++ b/aten/src/ATen/test/native_test.cpp
@@ -124,5 +124,33 @@
ASSERT_THROWS(d5.matmul(d5wrong), "must match the size");
}
+ // _standard_gamma_grad
+ {
+ // check empty
+ auto empty = T.ones({0});
+ ASSERT_EQUAL(empty, empty._standard_gamma_grad(empty));
+
+ // check scalar equals one element
+ auto one_scalar = T.ones({}).mul(5);
+ auto one_with_dim = T.ones({1}).mul(5);
+ ASSERT_ALLCLOSE(one_scalar._standard_gamma_grad(one_scalar),
+ one_with_dim._standard_gamma_grad(one_with_dim).sum());
+
+ // check types
+ Type & DT = CPU(kDouble);
+ auto t1 = T.randn({3, 4});
+ auto t2 = DT.randn({3, 4});
+ ASSERT_THROWS(t1._standard_gamma_grad(t2), "expected scalar type");
+ if(at::hasCUDA()) {
+ Type & CT = CUDA(kFloat);
+ auto ct1 = CT.randn({3, 4});
+ auto ct2 = CT.randn({3, 4});
+
+ ASSERT_THROWS(ct1._standard_gamma_grad(ct2), "not implemented");
+ ASSERT_THROWS(ct1._standard_gamma_grad(t2), "not implemented");
+ ASSERT_THROWS(t1._standard_gamma_grad(ct2), "CUDA Backend");
+ }
+ }
+
return 0;
}
diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp
index 46a971b..a0a129b 100644
--- a/aten/src/ATen/test/scalar_test.cpp
+++ b/aten/src/ATen/test/scalar_test.cpp
@@ -15,19 +15,16 @@
template<typename scalar_type>
struct Foo {
- static void CPU(const Type & t, Tensor a, Tensor b) {
+ static void apply(Tensor a, Tensor b) {
scalar_type s = 1;
- cout << "hello, dispatch: " << t.toString() << s << "\n";
+ cout << "hello, dispatch: " << a.type().toString() << s << "\n";
auto data = (scalar_type*)a.data_ptr();
(void)data;
}
- static void CUDA(const Type & t, Tensor a, Tensor b) {
- }
};
template<>
struct Foo<Half> {
- static void CPU(const Type & t, Tensor a, Tensor b) {}
- static void CUDA(const Type & t, Tensor a, Tensor b) {}
+ static void apply(Tensor a, Tensor b) {}
};
void test_ctors() {
@@ -142,7 +139,7 @@
ASSERT(what.toTensor().type().scalarType() == kLong);
ASSERT(Scalar(CPU(kFloat).ones({})).toTensor().type().scalarType() == kFloat);
- dispatch<Foo>(x.type(),x,prev_h);
+ dispatch_all<Foo>(x.type(),"foo",x,prev_h);
// test direct C-scalar type conversions
try {
diff --git a/aten/src/TH/generic/THTensorMath.c b/aten/src/TH/generic/THTensorMath.c
index dd473e5..10629ab 100644
--- a/aten/src/TH/generic/THTensorMath.c
+++ b/aten/src/TH/generic/THTensorMath.c
@@ -3486,57 +3486,6 @@
return (TH_MATH_NAME(lgamma)(x + eps) - TH_MATH_NAME(lgamma)(x - eps)) / (eps + eps);
}
-/** Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
- for random number x drawn from a standard Gamma distribution Gamma(alpha).
-*/
-static inline real THTensor_(standard_gamma_grad_one)(real x, real alpha) {
- // Use an asymptotic approximation for small x.
- if (x < 0.2f) {
- const real a0 = 1 / alpha;
- const real a1 = 1 / (alpha + 1);
- const real a2 = 1 / (alpha + 2);
- const real pow_x_alpha = TH_MATH_NAME(pow)(x, alpha);
- const real gamma_pdf = TH_MATH_NAME(pow)(x, alpha - 1) * TH_MATH_NAME(exp)(-x);
- const real gamma_cdf = pow_x_alpha * (a0 - x*a1 + 0.5f*x*x*a2);
- const real gamma_cdf_alpha = (TH_MATH_NAME(log)(x) - THTensor_(digamma_one)(alpha)) * gamma_cdf
- - pow_x_alpha * (a0*a0 - x*a1*a1 + 0.5f*x*x*a2*a2);
- const real result = -gamma_cdf_alpha / gamma_pdf;
- return isnan(result) ? 0 : result;
- }
-
- // Use an asymptotic approximation for large alpha.
- if (alpha > 50.0f) {
- return TH_MATH_NAME(sqrt)(x / alpha);
- }
-
- // Use a bivariate rational approximation to the reparameterized gradient.
- const real u = TH_MATH_NAME(log)(x / alpha);
- const real v = TH_MATH_NAME(log)(alpha);
- static const real coef_uv[3][8] = {
- {0.16028008, -0.088064309, 0.019630876, -0.0016920282,
- 1.0, 0.36659853, 0.10843863, 0.0066895454},
- {0.521894, 0.16095838, 0.06237597, 0.0023884253,
- 0.083457714, 0.0073297628, -0.0059299053, -0.00093720389},
- {-0.0031143957, -0.012143877, -0.0057656484, -0.00064847254,
- 0.0087262576, -0.00022820524, 1.8871047e-05, 9.6307964e-06},
- };
- real coef_v[8];
- for (int i = 0; i < 8; ++ i) {
- coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
- }
- const real p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
- const real q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
- return TH_MATH_NAME(exp)(p / q);
-}
-
-void THTensor_(standard_gamma_grad)(THTensor *self, THTensor *x, THTensor *alpha)
-{
- THTensor_(resizeAs)(self, x);
- TH_TENSOR_APPLY3(real, self, real, x, real, alpha, {
- *self_data = THTensor_(standard_gamma_grad_one)(*x_data, *alpha_data);
- });
-}
-
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes x is close to zero.
static inline real THTensor_(beta_grad_alpha_small)(real x, real alpha, real beta) {
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 695ef67..df4b537 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -195,7 +195,6 @@
TH_API void THTensor_(rand)(THTensor *r_, THGenerator *_generator, THLongStorage *size);
TH_API void THTensor_(randn)(THTensor *r_, THGenerator *_generator, THLongStorage *size);
-TH_API void THTensor_(standard_gamma_grad)(THTensor *self, THTensor *x, THTensor *alpha);
TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alpha, THTensor *total);
#endif
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 6bc7d4c..c813d5b 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -653,6 +653,9 @@
- name: _sparse_mask(Tensor self, SparseTensor mask)
self: not_implemented("_sparse_mask")
+- name: _standard_gamma(Tensor self, Generator generator)
+ self: grad * self._standard_gamma_grad(output)
+
# NN
- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight, bool size_average, bool reduce)
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index b3e319c..87ead71 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -325,7 +325,6 @@
IMPLEMENT_STATELESS(multinomial)
IMPLEMENT_STATELESS(normal)
IMPLEMENT_STATELESS(standard_gamma)
-IMPLEMENT_STATELESS(standard_gamma_grad)
IMPLEMENT_STATELESS(dirichlet_grad)
IMPLEMENT_STATELESS(bernoulli)
IMPLEMENT_STATELESS(range)
@@ -710,7 +709,6 @@
{"multinomial", (PyCFunction)THPModule_multinomial, METH_VARARGS | METH_KEYWORDS, NULL},
{"normal", (PyCFunction)THPModule_normal, METH_VARARGS | METH_KEYWORDS, NULL},
{"_standard_gamma", (PyCFunction)THPModule_standard_gamma, METH_VARARGS | METH_KEYWORDS, NULL},
- {"_standard_gamma_grad", (PyCFunction)THPModule_standard_gamma_grad, METH_VARARGS | METH_KEYWORDS, NULL},
{"_dirichlet_grad", (PyCFunction)THPModule_dirichlet_grad, METH_VARARGS | METH_KEYWORDS, NULL},
{"bernoulli", (PyCFunction)THPModule_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL},
{"rand", (PyCFunction)THPModule_rand, METH_VARARGS | METH_KEYWORDS, NULL},
diff --git a/torch/csrc/generic/methods/TensorRandom.cwrap b/torch/csrc/generic/methods/TensorRandom.cwrap
index db6dba5..75d1439 100644
--- a/torch/csrc/generic/methods/TensorRandom.cwrap
+++ b/torch/csrc/generic/methods/TensorRandom.cwrap
@@ -231,24 +231,6 @@
]]
[[
- name: standard_gamma_grad
- types:
- - floating_point
- backends:
- - CPU
- return: argument 0
- variants:
- - function
- options:
- - cname: standard_gamma_grad
- arguments:
- - arg: THTensor* output
- output: True
- - THTensor* x
- - THTensor* alpha
-]]
-
-[[
name: dirichlet_grad
types:
- floating_point
diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py
index 87b8ddc..4ac0e9c 100644
--- a/torch/distributions/gamma.py
+++ b/torch/distributions/gamma.py
@@ -7,25 +7,10 @@
from torch.distributions.utils import broadcast_all
-class _StandardGamma(Function):
- @staticmethod
- def forward(ctx, alpha):
- x = torch._C._standard_gamma(alpha)
- ctx.save_for_backward(x, alpha)
- return x
-
- @staticmethod
- @once_differentiable
- def backward(ctx, grad_output):
- x, alpha = ctx.saved_tensors
- grad = torch._C._standard_gamma_grad(x, alpha)
- return grad_output * grad
-
-
def _standard_gamma(alpha):
if not isinstance(alpha, Variable):
return torch._C._standard_gamma(alpha)
- return _StandardGamma.apply(alpha)
+ return alpha._standard_gamma()
class Gamma(Distribution):