[MPS] Fix gradient issues with NLL and Smooth_L1 loss ops (#94226)
- Fix correctness issues in nll_loss_backward(), smooth_l1_loss_backward(), and cross_entropy_backward() by taking grad_output into account when computing these loss ops (see the sketch below the PR links)
- Add numel()==0 check to prevent crashes
- General code cleanup and formatting
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94226
Approved by: https://github.com/kulinseth
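
A minimal repro sketch of the bug class this patch fixes (not part of the change itself; the device check and tolerances are illustrative). Before this fix, the MPS backward kernels for these losses ignored the incoming `grad_output`, so any upstream scaling of the loss produced wrong gradients. The snippet compares MPS gradients against the CPU reference for a non-trivial `grad_output`:

```python
import torch
import torch.nn.functional as F

def nll_grad(x, target, device):
    # Same data on both devices; only the backend differs.
    x = x.detach().clone().to(device).requires_grad_()
    loss = F.nll_loss(F.log_softmax(x, dim=1), target.to(device), reduction="mean")
    # Non-trivial upstream gradient: previously the MPS backward effectively treated it as 1.
    loss.backward(torch.tensor(2.0, device=device))
    return x.grad.cpu()

x = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])
if torch.backends.mps.is_available():
    torch.testing.assert_close(nll_grad(x, target, "cpu"),
                               nll_grad(x, target, "mps"),
                               atol=1e-5, rtol=1e-5)
```

The new `nn.functional.nll_loss` / `nn.functional.smooth_l1_loss` entries added to `test_mps.py` below cover the same behaviour; something like `python test/test_mps.py -k nll_loss` should exercise them on an MPS machine.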
diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index 5e46582..086e594 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -1,13 +1,6 @@
// Copyright © 2022 Apple Inc.
#include <ATen/native/mps/OperationUtils.h>
-#include <ATen/mps/MPSStream.h>
-#include <objc/NSObjCRuntime.h>
-#include <torch/library.h>
-
-#ifdef __OBJC__
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#endif
namespace at::native {
namespace mps {
@@ -37,12 +30,6 @@
}
}
-// MSELoss
-void mse_loss_out_impl(const Tensor& input, const Tensor& target,
- int64_t reduction, const Tensor& output, const string op_name)
-{
-}
-
Tensor& mse_loss_backward_out_impl(const Tensor& grad_output, const Tensor& input, const Tensor& target,
int64_t reduction, Tensor& grad_input, const string op_name)
{
@@ -313,163 +300,141 @@
// NLLLoss
void nllnd_loss_backward_impl(
-Tensor& grad_input_arg,
-const Tensor& grad_output,
-const Tensor& input_arg,
-const Tensor& target_arg,
-const Tensor& weight,
-int64_t reduction,
-int64_t ignore_index,
-const Tensor& total_weight,
-bool is2D)
-{
- // Empty output
- if(grad_input_arg.numel() == 0)
+ Tensor& grad_input_arg,
+ const Tensor& grad_output_arg,
+ const Tensor& input_arg,
+ const Tensor& target_arg,
+ const Tensor& weight_arg,
+ int64_t reduction,
+ int64_t ignore_index,
+ const Tensor& total_weight,
+ bool is2D) {
+
+ if (grad_input_arg.numel() == 0) {
return;
-
- MPSStream* stream = getCurrentMPSStream();
-
- struct CachedGraph : public MPSCachedGraph
- {
+ }
+ struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* targetTensor_ = nil;
MPSGraphTensor* weightTensor_ = nil;
MPSGraphTensor* totalWeightTensor_ = nil;
MPSGraphTensor* gradInputTensor_ = nil;
+ MPSGraphTensor* gradOutputTensor_ = nil;
};
-
- MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
+ bool isWeightsArrayValid = weight_arg.defined() && weight_arg.numel() > 0;
+ int64_t channel_dim = grad_input_arg.dim() < 2 ? 0 : 1;
auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg;
+ auto numClasses = grad_input.sizes()[1];
+ auto weight = weight_arg;
+ auto grad_output = grad_output_arg;
+ if (isWeightsArrayValid) {
+ std::vector<int64_t> weightShape(input.dim(), 1);
+ weightShape[1] = input.size(1);
+ weight = weight_arg.view(weightShape);
+ }
+ if (grad_output_arg.dim() < grad_input.dim() && grad_output_arg.dim() > 0) {
+ grad_output = grad_output_arg.unsqueeze(channel_dim);
+ }
@autoreleasepool {
+ string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight})
+ + to_string(numClasses) + ":" + to_string(ignore_index) + ":"
+ + to_string(isWeightsArrayValid) + ":" + reductionToString(reduction);
- auto numClasses = grad_input.sizes()[1];
- bool isWeightsArrayValid = (weight.numel() > 0);
-
- MPSShape* input_shape = getMPSShape(input);
- MPSShape* target_shape = getMPSShape(target);
- MPSShape* weight_shape = getMPSShape(weight);
- MPSShape* total_weight_shape = getMPSShape(total_weight);
-
- NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
-
- string key = "nllnd_loss_backward_impl:" + to_string(numClasses) + ":" +
- to_string(ignore_index) + ":" +
- to_string(isWeightsArrayValid) + ":" +
- reductionToString(reduction) + ":" +
- [ns_shape_key UTF8String] + ":" +
- getMPSTypeString(input.scalar_type()) + ":" +
- getMPSTypeString(target.scalar_type()) + ":" +
- getMPSTypeString(weight.scalar_type()) + ":" +
- getMPSTypeString(total_weight.scalar_type());
- CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
+ MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+ CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
- MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+ cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
-
@autoreleasepool {
-
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape);
- MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target.scalar_type()), target_shape);
+ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
+ MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
MPSGraphTensor* weightTensor = nil;
- if(isWeightsArrayValid)
- weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(weight.scalar_type()), weight_shape);
- MPSGraphTensor* totalWeightTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(total_weight.scalar_type()), total_weight_shape);
+ if (isWeightsArrayValid) {
+ weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
+ }
+ MPSGraphTensor* totalWeightTensor = mpsGraphRankedPlaceHolder(mpsGraph, total_weight);
+ MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor *udpatedTargetTensor = targetTensor;
// Replace ignored_index with length depth + 1 so that oneHotAPI ignores it
- if(ignore_index != -100)
- {
- MPSGraphTensor *mpsGraphIndexTensor = [mpsGraph constantWithScalar: ignore_index
- dataType: MPSDataTypeInt64];
- MPSGraphTensor *mpsGraphDepthPlusOneTensor = [mpsGraph constantWithScalar: (numClasses + 1)
- dataType: MPSDataTypeInt64];
-
- // Equal tensor
- MPSGraphTensor* mpsGraphIsEqualTensor = [mpsGraph equalWithPrimaryTensor: targetTensor
- secondaryTensor: mpsGraphIndexTensor
- name: @"isEqualTensor"];
-
- udpatedTargetTensor = [mpsGraph selectWithPredicateTensor: mpsGraphIsEqualTensor
- truePredicateTensor: mpsGraphDepthPlusOneTensor
- falsePredicateTensor: targetTensor
- name: @"predicateTensor"];
+ if (ignore_index != -100) {
+ MPSGraphTensor *ignoreIndexTensor = [mpsGraph constantWithScalar: ignore_index
+ dataType: MPSDataTypeInt64];
+ MPSGraphTensor *numClassesTensor = [mpsGraph constantWithScalar: (numClasses + 1)
+ dataType: MPSDataTypeInt64];
+ MPSGraphTensor* isEqualTensor = [mpsGraph equalWithPrimaryTensor: targetTensor
+ secondaryTensor: ignoreIndexTensor
+ name: @"isEqualTensor"];
+ udpatedTargetTensor = [mpsGraph selectWithPredicateTensor: isEqualTensor
+ truePredicateTensor: numClassesTensor
+ falsePredicateTensor: targetTensor
+ name: @"predicateTensor"];
}
-
- float onValue = -1.0f;
-
- MPSGraphTensor *oneHotTensor;
-
- oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor
- depth:numClasses
- axis:1
- dataType:inputTensor.dataType
- onValue:onValue
- offValue:0.0f
- name:nil];
-
- if(isWeightsArrayValid)
- {
- oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
- secondaryTensor:weightTensor
- name:@"scaleByWeightTensor"];
+ MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor: udpatedTargetTensor
+ depth: numClasses
+ axis: 1
+ dataType: inputTensor.dataType
+ onValue: -1.0f
+ offValue: 0.0f
+ name: nil];
+ if (isWeightsArrayValid) {
+ oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor: oneHotTensor
+ secondaryTensor: weightTensor
+ name: @"scaleByWeightTensor"];
}
-
- if(reduction == Reduction::Mean)
- {
- oneHotTensor = [mpsGraph divisionNoNaNWithPrimaryTensor:oneHotTensor
- secondaryTensor:totalWeightTensor
- name:@"divisionTensor"];
+ if (reduction == Reduction::Mean) {
+ oneHotTensor = [mpsGraph divisionNoNaNWithPrimaryTensor: oneHotTensor
+ secondaryTensor: totalWeightTensor
+ name: @"divisionTensor"];
}
-
- MPSGraphTensor* gradInputTensor = oneHotTensor;
-
+ MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor: oneHotTensor
+ secondaryTensor: gradOutputTensor
+ name: nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->targetTensor_ = targetTensor;
newCachedGraph->weightTensor_ = weightTensor;
newCachedGraph->totalWeightTensor_ = totalWeightTensor;
newCachedGraph->gradInputTensor_ = gradInputTensor;
-
+ newCachedGraph->gradOutputTensor_ = gradOutputTensor;
}
return newCachedGraph;
});
- cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
- auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
- auto targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target);
+ auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
+ auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
+ auto targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target);
Placeholder weightPlaceholder = Placeholder();
- if(isWeightsArrayValid)
- weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
- auto totalWeightPlaceholder = Placeholder(cachedGraph->totalWeightTensor_, total_weight);
+ if(isWeightsArrayValid) {
+ weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
+ }
+ auto totalWeightPlaceholder = Placeholder(cachedGraph->totalWeightTensor_, total_weight);
auto gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
- NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [[[NSMutableDictionary alloc] initWithCapacity: 4] autorelease];
+ NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData();
feeds[totalWeightPlaceholder.getMPSGraphTensor()] = totalWeightPlaceholder.getMPSGraphTensorData();
+ feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
- if(isWeightsArrayValid)
+ if (isWeightsArrayValid) {
feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
-
+ }
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};
- runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
-
- return;
}
void nllnd_loss_forward_impl
@@ -907,132 +872,101 @@
double beta,
Tensor& grad_input)
{
- struct CachedGraph : public MPSCachedGraph
- {
+ if (grad_input.numel() == 0) {
+ return;
+ }
+ TORCH_CHECK(beta >= 0, "smooth_l1_loss_backward does not support negative values for beta.");
+
+ struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *targetTensor_ = nil;
MPSGraphTensor *gradInputTensor_ = nil;
+ MPSGraphTensor* gradOutputTensor_ = nil;
};
- MPSGraphCache *cache_ = MPSGraphCache::getInstance();
-
- MPSStream *stream= getCurrentMPSStream();
-
@autoreleasepool {
+ string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":"
+ + reductionToString(reduction) + ":" + to_string(beta);
- auto numClasses = grad_input.sizes()[1];
- MPSShape* input_shape = getMPSShape(input);
- NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
-
- string key = "smooth_l1_loss_backward_impl:" + to_string(numClasses) + ":" +
- reductionToString(reduction) + ":" +
- [ns_shape_key UTF8String] + ":" +
- to_string(beta) + ":" +
- getMPSTypeString(input.scalar_type()) + ":" +
- getMPSTypeString(target.scalar_type());
- CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
- if(!cachedGraph) {
- MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+ MPSGraphCache *cache_ = MPSGraphCache::getInstance();
+ CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
+ if (!cachedGraph) {
+ cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool {
- auto numElements = input.numel();
-
MPSGraph *mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor *inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()));
- MPSGraphTensor *targetTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(target.scalar_type()));
+ MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
+ MPSGraphTensor *targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
+ MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor *betaTensor = [mpsGraph constantWithScalar: beta
dataType: MPSDataTypeFloat32];
-
- MPSGraphTensor *numelTensor = [mpsGraph constantWithScalar: numElements
- dataType: MPSDataTypeFloat32];
-
// xn - yn
MPSGraphTensor *diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor
secondaryTensor: targetTensor
name: nil];
-
// | xn - yn |
MPSGraphTensor *diffAbsTensor = [mpsGraph absoluteWithTensor: diffTensor
name: nil];
-
// | xn - yn | < beta
MPSGraphTensor *diffAbsLessThanBetaTensor = [mpsGraph lessThanWithPrimaryTensor: diffAbsTensor
secondaryTensor: betaTensor
name: nil];
-
// ( xn - yn ) / beta
MPSGraphTensor *truePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor
secondaryTensor: betaTensor
name: nil];
-
// ( x - y ) / | x - y |
- MPSGraphTensor *falsePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor
- secondaryTensor: diffAbsTensor
- name: nil];
+ MPSGraphTensor *falsePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor
+ secondaryTensor: diffAbsTensor
+ name: nil];
MPSGraphTensor *lossTensor = [mpsGraph selectWithPredicateTensor: diffAbsLessThanBetaTensor
- truePredicateTensor: truePredicateTensor
- falsePredicateTensor: falsePredicateTensor
- name: @"lossTensor"];
-
+ truePredicateTensor: truePredicateTensor
+ falsePredicateTensor: falsePredicateTensor
+ name: @"lossTensor"];
MPSGraphTensor *outputTensor = lossTensor;
- if (reduction == Reduction::Mean)
- {
- outputTensor = [mpsGraph divisionWithPrimaryTensor: lossTensor
- secondaryTensor: numelTensor
- name: nil];
+ if (reduction == Reduction::Mean) {
+ MPSGraphTensor *numelTensor = [mpsGraph constantWithScalar: (double) input.numel()
+ dataType: MPSDataTypeFloat32];
+ outputTensor = [mpsGraph divisionWithPrimaryTensor: lossTensor
+ secondaryTensor: numelTensor
+ name: nil];
}
-
- MPSGraphTensor *gradInputTensor = outputTensor;
-
+ MPSGraphTensor *gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor: outputTensor
+ secondaryTensor: gradOutputTensor
+ name: nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->targetTensor_ = targetTensor;
newCachedGraph->gradInputTensor_ = gradInputTensor;
-
+ newCachedGraph->gradOutputTensor_ = gradOutputTensor;
}
return newCachedGraph;
});
- cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target);
Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
+ Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
- targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData()
+ targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData(),
+ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};
- runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
}
-void smooth_l1_loss_backward_template(
- const Tensor& grad_output,
- const Tensor& input,
- const Tensor& target,
- int64_t reduction,
- double beta,
- Tensor& grad_input)
-{
- TORCH_CHECK(beta >= 0, "smooth_l1_loss_backward does not support negative values for beta.");
- TORCH_CHECK(input.is_mps());
- TORCH_CHECK(target.is_mps());
-
- smooth_l1_loss_backward_impl(
- grad_output, input, target, reduction, beta, grad_input
- );
-}
-
} // namespace mps
// APIs exposed to at::native scope
@@ -1390,8 +1324,10 @@
int64_t reduction,
double beta,
Tensor& grad_input) {
- mps::smooth_l1_loss_backward_template(
+
+ mps::smooth_l1_loss_backward_impl(
grad_output, input, target, reduction, beta, grad_input);
+
return grad_input;
}
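
For reference, a rough Python restatement (an assumption about the intended math, not code from this patch) of the gradient the rewritten `smooth_l1_loss_backward` graph computes: `(x - y) / beta` inside the quadratic zone, `sign(x - y)` outside it, averaged over the element count for mean reduction, then scaled by `grad_output`:

```python
import torch

def smooth_l1_grad_ref(grad_output, inp, target, beta=1.0, reduction="mean"):
    d = inp - target
    # d/beta where |d| < beta, else the sign of d (the MPS graph uses d/|d| here).
    g = torch.where(d.abs() < beta, d / beta, torch.sign(d))
    if reduction == "mean":
        g = g / inp.numel()
    return g * grad_output

# Sanity check against autograd on CPU.
x = torch.randn(5, requires_grad=True)
y = torch.randn(5)
torch.nn.functional.smooth_l1_loss(x, y, beta=1.0).backward()
assert torch.allclose(x.grad, smooth_l1_grad_ref(torch.tensor(1.0), x.detach(), y))
```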
diff --git a/test/test_mps.py b/test/test_mps.py
index e258c3f..e0a0527 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2449,13 +2449,15 @@
num_channels = input_size[1]
target_size = (input_size[0], ) + tuple(input_size[2:])
target = torch.randint(num_channels, target_size, device='cpu')
+ weights = torch.randn(num_channels)
# MPS
input_mps = input.detach().clone().to('mps').requires_grad_()
target_mps = target.detach().clone().to('mps')
+ weights_mps = weights.to("mps")
- output_cpu = F.nll_loss(input, target, reduction=reduction)
- output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
+ output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
+ output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
self.assertEqual(output_cpu, output_mps.to('cpu'))
output_cpu.sum().backward()
@@ -8369,6 +8371,7 @@
'nn.functional.max_pool2d': ['f32'],
'max_pool2d_with_indices_backward': ['f32'],
'nn.functional.mse_loss': ['f16', 'f32'],
+ 'nn.functional.nll_loss': ['f32'],
'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'nn.functional.padconstant': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'nn.functional.padreflect': ['f32'],
@@ -8584,6 +8587,7 @@
'nn.functional.max_pool1d': ['f32'],
'nn.functional.max_pool2d': ['f32'],
'nn.functional.mse_loss': ['f32'],
+ 'nn.functional.nll_loss': ['f32'],
'nn.functional.pad': ['f16', 'f32', 'i16', 'i32', 'i64'],
'nn.functional.pairwise_distance': ['f16', 'f32'],
'nn.functional.poisson_nll_loss': ['f32'],
@@ -8595,6 +8599,7 @@
'nn.functional.softmin': ['f32'],
'nn.functional.softplus': ['f32'],
'nn.functional.softsign': ['f16', 'f32'],
+ 'nn.functional.smooth_l1_loss': ['f32'],
'nn.functional.threshold': ['f32'],
'nn.functional.triplet_margin_loss': ['f32'],
'nn.functional.triplet_margin_with_distance_loss': ['f32'],
@@ -8655,7 +8660,6 @@
'masked.sum': [torch.bool],
# Functions that hard crash
- 'nn.functional.nll_loss': [torch.float32],
'std': [torch.float16],
'stft': [torch.float32], 'var': [torch.float16],
# + forward when requires_grad=True or running backward