reduce igamma instantiations (#70666)
Summary:
Don't compile scalar versions of the kernel (there is no scalar overload, so the extra instantiations generated by gpu_kernel_with_scalars were never used; plain gpu_kernel suffices), and combine the igamma and igammac kernels into a single functor selected by a runtime flag.
This shrinks the igamma cubin from 10 MB to 2 MB on V100.
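
For illustration, here is a minimal standalone sketch of the pattern (not PyTorch code: op_a, op_b, CombinedOp, and elementwise_kernel are hypothetical stand-ins for calc_igamma, calc_igammac, CalcIgamma, and the gpu_kernel machinery). Two related ops share one functor selected by a runtime flag, so each dtype produces one kernel instantiation instead of two:

    #include <cstdio>
    #include <cuda_runtime.h>

    // Stand-ins for calc_igamma / calc_igammac; the real math is irrelevant here.
    template <typename scalar_t>
    __device__ scalar_t op_a(scalar_t x, scalar_t y) { return x + y; }
    template <typename scalar_t>
    __device__ scalar_t op_b(scalar_t x, scalar_t y) { return x - y; }

    // One functor for both ops. The flag is a runtime value, not a template
    // parameter, so it does not multiply kernel instantiations.
    template <typename scalar_t>
    struct CombinedOp {
      bool use_b;
      __device__ scalar_t operator()(scalar_t x, scalar_t y) const {
        return use_b ? op_b(x, y) : op_a(x, y);
      }
    };

    // A single kernel template serves both ops for each dtype.
    template <typename func_t, typename scalar_t>
    __global__ void elementwise_kernel(func_t f, const scalar_t* a,
                                       const scalar_t* b, scalar_t* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = f(a[i], b[i]);
    }

    int main() {
      const int n = 4;
      float ha[n] = {1, 2, 3, 4}, hb[n] = {4, 3, 2, 1}, hout[n];
      float *da, *db, *dout;
      cudaMalloc(&da, n * sizeof(float));
      cudaMalloc(&db, n * sizeof(float));
      cudaMalloc(&dout, n * sizeof(float));
      cudaMemcpy(da, ha, n * sizeof(float), cudaMemcpyHostToDevice);
      cudaMemcpy(db, hb, n * sizeof(float), cudaMemcpyHostToDevice);
      // Both launches reuse the same instantiation; only the flag differs.
      elementwise_kernel<<<1, 32>>>(CombinedOp<float>{false}, da, db, dout, n);
      elementwise_kernel<<<1, 32>>>(CombinedOp<float>{true}, da, db, dout, n);
      cudaMemcpy(hout, dout, n * sizeof(float), cudaMemcpyDeviceToHost);
      printf("%f %f %f %f\n", hout[0], hout[1], hout[2], hout[3]);
      cudaFree(da); cudaFree(db); cudaFree(dout);
      return 0;
    }

Making the flag a template parameter instead would be resolved at compile time and bring back one instantiation per op; as a runtime branch it is uniform across every thread of a launch and cheap next to the incomplete-gamma math itself.
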
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70666
Reviewed By: malfet
Differential Revision: D33431359
Pulled By: ngimel
fbshipit-source-id: 440998f751251be274f40dd035efba08b8969192
diff --git a/aten/src/ATen/native/cuda/IGammaKernel.cu b/aten/src/ATen/native/cuda/IGammaKernel.cu
index 896725d..41dc631 100644
--- a/aten/src/ATen/native/cuda/IGammaKernel.cu
+++ b/aten/src/ATen/native/cuda/IGammaKernel.cu
@@ -510,6 +510,19 @@
return _igam_helper_series(a, x);
}
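+// Shared functor for the igamma and igammac kernels: a runtime flag picks
+// which device helper to call, so each dtype gets one kernel instantiation
+// instead of two.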
+template<typename scalar_t>
+struct CalcIgamma {
+  CalcIgamma(bool calc_igammac) : calc_igammac_(calc_igammac) {}
+  bool calc_igammac_;
+  __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
+    if (calc_igammac_) {
+      return calc_igammac(a, b);
+    } else {
+      return calc_igamma(a, b);
+    }
+  }
+};
+
}
// end of regularized lower & upper incomplete gamma
@@ -519,18 +532,14 @@
void igamma_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igamma_cuda", [&]() {
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-      return calc_igamma(a, b);
-    });
+    gpu_kernel(iter, CalcIgamma<scalar_t>(false));
  });
}
void igammac_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igammac_cuda", [&]() {
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-      return calc_igammac(a, b);
-    });
+    gpu_kernel(iter, CalcIgamma<scalar_t>(true));
  });
}