[MPS] Fix gradient issues with NLL and Smooth_L1 loss ops (#94226)

- Fix correctness issues with nll_loss_backward(), smooth_l1_loss_backward() and cross_entropy_backward() by taking grad_output into account when computing the gradients for those loss ops (a minimal verification sketch follows the PR links below)
- Add a numel() == 0 check to prevent crashes on empty tensors
- Clean up code and formatting
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94226
Approved by: https://github.com/kulinseth
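
In short: the MPS backward kernels built the local gradient but never multiplied it by the incoming grad_output, so results were only correct when the upstream gradient happened to be all ones. A minimal sketch of how to verify the fix against the CPU reference (the shapes, scale factor, and helper name are illustrative, not taken from this PR's test suite):

```python
import torch
import torch.nn.functional as F

def check_nll_backward(reduction):
    # Identical data on CPU (reference) and MPS.
    x_cpu = torch.randn(4, 5, requires_grad=True)
    t_cpu = torch.randint(5, (4,))
    x_mps = x_cpu.detach().clone().to("mps").requires_grad_()
    t_mps = t_cpu.to("mps")

    out_cpu = F.nll_loss(F.log_softmax(x_cpu, dim=1), t_cpu, reduction=reduction)
    out_mps = F.nll_loss(F.log_softmax(x_mps, dim=1), t_mps, reduction=reduction)

    # Use a non-trivial upstream gradient: a kernel that drops grad_output
    # would produce MPS gradients off by exactly this factor.
    out_cpu.backward(torch.tensor(2.0))
    out_mps.backward(torch.tensor(2.0, device="mps"))
    torch.testing.assert_close(x_cpu.grad, x_mps.grad.cpu())

for reduction in ("mean", "sum"):
    check_nll_backward(reduction)
```

With reduction="none" the loss is not a scalar, so the same check would pass a full-shaped gradient tensor to backward() instead of a scalar.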
diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index 5e46582..086e594 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -1,13 +1,6 @@
 //  Copyright © 2022 Apple Inc.
 
 #include <ATen/native/mps/OperationUtils.h>
-#include <ATen/mps/MPSStream.h>
-#include <objc/NSObjCRuntime.h>
-#include <torch/library.h>
-
-#ifdef __OBJC__
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#endif
 
 namespace at::native {
 namespace mps {
@@ -37,12 +30,6 @@
     }
 }
 
-// MSELoss
-void mse_loss_out_impl(const Tensor& input, const Tensor& target,
-                          int64_t reduction, const Tensor& output, const string op_name)
-{
-}
-
 Tensor& mse_loss_backward_out_impl(const Tensor& grad_output, const Tensor& input, const Tensor& target,
                                    int64_t reduction, Tensor& grad_input, const string op_name)
 {
@@ -313,163 +300,141 @@
 
 // NLLLoss
 void nllnd_loss_backward_impl(
-Tensor& grad_input_arg,
-const Tensor& grad_output,
-const Tensor& input_arg,
-const Tensor& target_arg,
-const Tensor& weight,
-int64_t reduction,
-int64_t ignore_index,
-const Tensor& total_weight,
-bool is2D)
-{
-    // Empty output
-    if(grad_input_arg.numel() == 0)
+    Tensor& grad_input_arg,
+    const Tensor& grad_output_arg,
+    const Tensor& input_arg,
+    const Tensor& target_arg,
+    const Tensor& weight_arg,
+    int64_t reduction,
+    int64_t ignore_index,
+    const Tensor& total_weight,
+    bool is2D) {
+
+    if (grad_input_arg.numel() == 0) {
         return;
-
-    MPSStream* stream = getCurrentMPSStream();
-
-    struct CachedGraph : public MPSCachedGraph
-    {
+    }
+    struct CachedGraph : public MPSCachedGraph {
         CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
         MPSGraphTensor* inputTensor_ = nil;
         MPSGraphTensor* targetTensor_ = nil;
         MPSGraphTensor* weightTensor_ = nil;
         MPSGraphTensor* totalWeightTensor_ = nil;
         MPSGraphTensor* gradInputTensor_ = nil;
+        MPSGraphTensor* gradOutputTensor_ = nil;
     };
-
-    MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
+    bool isWeightsArrayValid = weight_arg.defined() && weight_arg.numel() > 0;
+    int64_t channel_dim = grad_input_arg.dim() < 2 ? 0 : 1;
     auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
     auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
     auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg;
+    auto numClasses = grad_input.sizes()[1];
+    auto weight = weight_arg;
+    auto grad_output = grad_output_arg;
 
+    if (isWeightsArrayValid) {
+        std::vector<int64_t> weightShape(input.dim(), 1);
+        weightShape[1] = input.size(1);
+        weight = weight_arg.view(weightShape);
+    }
+    if (grad_output_arg.dim() < grad_input.dim() && grad_output_arg.dim() > 0) {
+        grad_output = grad_output_arg.unsqueeze(channel_dim);
+    }
     @autoreleasepool {
+        string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight})
+                                           + to_string(numClasses) + ":" + to_string(ignore_index) + ":"
+                                           + to_string(isWeightsArrayValid) + ":" + reductionToString(reduction);
 
-        auto numClasses = grad_input.sizes()[1];
-        bool isWeightsArrayValid = (weight.numel() > 0);
-
-        MPSShape* input_shape = getMPSShape(input);
-        MPSShape* target_shape = getMPSShape(target);
-        MPSShape* weight_shape = getMPSShape(weight);
-        MPSShape* total_weight_shape = getMPSShape(total_weight);
-
-        NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
-
-        string key = "nllnd_loss_backward_impl:" + to_string(numClasses) + ":" +
-                                                   to_string(ignore_index) + ":" +
-                                                   to_string(isWeightsArrayValid) + ":" +
-                                                   reductionToString(reduction) + ":" +
-                                                   [ns_shape_key UTF8String] + ":" +
-                                                   getMPSTypeString(input.scalar_type()) + ":" +
-                                                   getMPSTypeString(target.scalar_type()) + ":" +
-                                                   getMPSTypeString(weight.scalar_type()) + ":" +
-                                                   getMPSTypeString(total_weight.scalar_type());
-        CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
+        MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+        CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
         if(!cachedGraph) {
-            MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+            cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
 
                 CachedGraph *newCachedGraph = nil;
-
                 @autoreleasepool {
-
                     MPSGraph* mpsGraph = make_mps_graph();
                     newCachedGraph = new CachedGraph(mpsGraph);
 
-                    MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape);
-                    MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target.scalar_type()), target_shape);
+                    MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
+                    MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
                     MPSGraphTensor* weightTensor = nil;
-                    if(isWeightsArrayValid)
-                        weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(weight.scalar_type()), weight_shape);
-                    MPSGraphTensor* totalWeightTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(total_weight.scalar_type()), total_weight_shape);
+                    if (isWeightsArrayValid) {
+                        weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
+                    }
+                    MPSGraphTensor* totalWeightTensor = mpsGraphRankedPlaceHolder(mpsGraph, total_weight);
+                    MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
 
                     MPSGraphTensor *updatedTargetTensor = targetTensor;
 
                     // Replace ignored_index with length depth + 1 so that oneHotAPI ignores it
-                    if(ignore_index != -100)
-                    {
-                        MPSGraphTensor *mpsGraphIndexTensor = [mpsGraph constantWithScalar: ignore_index
-                                                                                  dataType: MPSDataTypeInt64];
-                        MPSGraphTensor *mpsGraphDepthPlusOneTensor = [mpsGraph constantWithScalar: (numClasses + 1)
-                                                                                  dataType: MPSDataTypeInt64];
-
-                        // Equal tensor
-                        MPSGraphTensor* mpsGraphIsEqualTensor = [mpsGraph equalWithPrimaryTensor: targetTensor
-                                                                                 secondaryTensor: mpsGraphIndexTensor
-                                                                                            name: @"isEqualTensor"];
-
-                        udpatedTargetTensor = [mpsGraph selectWithPredicateTensor: mpsGraphIsEqualTensor
-                                                          truePredicateTensor: mpsGraphDepthPlusOneTensor
-                                                         falsePredicateTensor: targetTensor
-                                                                         name: @"predicateTensor"];
+                    if (ignore_index != -100) {
+                        MPSGraphTensor *ignoreIndexTensor = [mpsGraph constantWithScalar: ignore_index
+                                                                                dataType: MPSDataTypeInt64];
+                        MPSGraphTensor *numClassesTensor  = [mpsGraph constantWithScalar: (numClasses + 1)
+                                                                                dataType: MPSDataTypeInt64];
+                        MPSGraphTensor* isEqualTensor = [mpsGraph equalWithPrimaryTensor: targetTensor
+                                                                         secondaryTensor: ignoreIndexTensor
+                                                                                    name: @"isEqualTensor"];
+                        updatedTargetTensor = [mpsGraph selectWithPredicateTensor: isEqualTensor
+                                                              truePredicateTensor: numClassesTensor
+                                                             falsePredicateTensor: targetTensor
+                                                                             name: @"predicateTensor"];
                     }
-
-                    float onValue = -1.0f;
-
-                    MPSGraphTensor *oneHotTensor;
-
-                    oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor
-                                                               depth:numClasses
-                                                                axis:1
-                                                            dataType:inputTensor.dataType
-                                                             onValue:onValue
-                                                            offValue:0.0f
-                                                                name:nil];
-
-                    if(isWeightsArrayValid)
-                    {
-                        oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
-                                                                 secondaryTensor:weightTensor
-                                                                            name:@"scaleByWeightTensor"];
+                    MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor: updatedTargetTensor
+                                                                               depth: numClasses
+                                                                                axis: 1
+                                                                            dataType: inputTensor.dataType
+                                                                             onValue: -1.0f
+                                                                            offValue: 0.0f
+                                                                                name: nil];
+                    if (isWeightsArrayValid) {
+                        oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor: oneHotTensor
+                                                                 secondaryTensor: weightTensor
+                                                                            name: @"scaleByWeightTensor"];
                     }
-
-                    if(reduction == Reduction::Mean)
-                    {
-                        oneHotTensor = [mpsGraph divisionNoNaNWithPrimaryTensor:oneHotTensor
-                                                                secondaryTensor:totalWeightTensor
-                                                                           name:@"divisionTensor"];
+                    if (reduction == Reduction::Mean) {
+                        oneHotTensor = [mpsGraph divisionNoNaNWithPrimaryTensor: oneHotTensor
+                                                                secondaryTensor: totalWeightTensor
+                                                                           name: @"divisionTensor"];
                     }
-
-                    MPSGraphTensor* gradInputTensor = oneHotTensor;
-
+                    MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor: oneHotTensor
+                                                                                secondaryTensor: gradOutputTensor
+                                                                                           name: nil];
                     newCachedGraph->inputTensor_ = inputTensor;
                     newCachedGraph->targetTensor_ = targetTensor;
                     newCachedGraph->weightTensor_ = weightTensor;
                     newCachedGraph->totalWeightTensor_ = totalWeightTensor;
                     newCachedGraph->gradInputTensor_ = gradInputTensor;
-
+                    newCachedGraph->gradOutputTensor_ = gradOutputTensor;
                 }
                 return newCachedGraph;
             });
-            cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
         }
 
-        auto inputPlaceholder   = Placeholder(cachedGraph->inputTensor_, input);
-        auto targetPlaceholder   = Placeholder(cachedGraph->targetTensor_, target);
+        auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
+        auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
+        auto targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target);
         Placeholder weightPlaceholder = Placeholder();
-        if(isWeightsArrayValid)
-            weightPlaceholder  = Placeholder(cachedGraph->weightTensor_, weight);
-        auto totalWeightPlaceholder   = Placeholder(cachedGraph->totalWeightTensor_, total_weight);
+        if (isWeightsArrayValid) {
+            weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
+        }
+        auto totalWeightPlaceholder = Placeholder(cachedGraph->totalWeightTensor_, total_weight);
         auto gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
 
-        NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [[[NSMutableDictionary alloc] initWithCapacity: 4] autorelease];
+        NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
         feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
         feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData();
         feeds[totalWeightPlaceholder.getMPSGraphTensor()] = totalWeightPlaceholder.getMPSGraphTensorData();
+        feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
 
-        if(isWeightsArrayValid)
+        if (isWeightsArrayValid) {
             feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
-
+        }
         NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
             gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
         };
 
-        runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+        runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
     }
-
-    return;
 }
 
 void nllnd_loss_forward_impl
@@ -907,132 +872,101 @@
     double beta,
     Tensor& grad_input)
 {
- struct CachedGraph : public MPSCachedGraph
-  {
+  if (grad_input.numel() == 0) {
+    return;
+  }
+  TORCH_CHECK(beta >= 0, "smooth_l1_loss_backward does not support negative values for beta.");
+
+  struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
     MPSGraphTensor *inputTensor_ = nil;
     MPSGraphTensor *targetTensor_ = nil;
     MPSGraphTensor *gradInputTensor_ = nil;
+    MPSGraphTensor* gradOutputTensor_ = nil;
   };
 
- MPSGraphCache *cache_ = MPSGraphCache::getInstance();
-
-  MPSStream *stream= getCurrentMPSStream();
-
   @autoreleasepool {
+    string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":"
+                                           + reductionToString(reduction) + ":" + to_string(beta);
 
-    auto numClasses = grad_input.sizes()[1];
-    MPSShape* input_shape = getMPSShape(input);
-    NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
-
-    string key = "smooth_l1_loss_backward_impl:" + to_string(numClasses) + ":" +
-                                                   reductionToString(reduction) + ":" +
-                                                   [ns_shape_key UTF8String] + ":" +
-                                                   to_string(beta) + ":" +
-                                                   getMPSTypeString(input.scalar_type()) + ":" +
-                                                   getMPSTypeString(target.scalar_type());
-    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-    if(!cachedGraph) {
-      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+    MPSGraphCache *cache_ = MPSGraphCache::getInstance();
+    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
+    if (!cachedGraph) {
+      cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
 
         CachedGraph *newCachedGraph = nil;
 
         @autoreleasepool {
-          auto numElements = input.numel();
-
           MPSGraph *mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor *inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()));
-          MPSGraphTensor *targetTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(target.scalar_type()));
+          MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
+          MPSGraphTensor *targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
+          MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
 
           MPSGraphTensor *betaTensor = [mpsGraph constantWithScalar: beta
                                                            dataType: MPSDataTypeFloat32];
-
-          MPSGraphTensor *numelTensor = [mpsGraph constantWithScalar: numElements
-                                                            dataType: MPSDataTypeFloat32];
-
           // xn - yn
           MPSGraphTensor *diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor
                                                               secondaryTensor: targetTensor
                                                                          name: nil];
-
           // | xn - yn |
           MPSGraphTensor *diffAbsTensor = [mpsGraph absoluteWithTensor: diffTensor
                                                                   name: nil];
-
           // | xn - yn | < beta
           MPSGraphTensor *diffAbsLessThanBetaTensor = [mpsGraph lessThanWithPrimaryTensor: diffAbsTensor
                                                                           secondaryTensor: betaTensor
                                                                                      name: nil];
-
           // ( xn - yn ) / beta
           MPSGraphTensor *truePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor
                                                                     secondaryTensor: betaTensor
                                                                                name: nil];
-
           // ( x - y ) / | x - y |
-           MPSGraphTensor *falsePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor
-                                                                      secondaryTensor: diffAbsTensor
-                                                                                 name: nil];
+          MPSGraphTensor *falsePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor
+                                                                     secondaryTensor: diffAbsTensor
+                                                                                name: nil];
 
           MPSGraphTensor *lossTensor = [mpsGraph selectWithPredicateTensor: diffAbsLessThanBetaTensor
-                                                            truePredicateTensor: truePredicateTensor
-                                                           falsePredicateTensor: falsePredicateTensor
-                                                                           name: @"lossTensor"];
-
+                                                       truePredicateTensor: truePredicateTensor
+                                                      falsePredicateTensor: falsePredicateTensor
+                                                                      name: @"lossTensor"];
           MPSGraphTensor *outputTensor = lossTensor;
-          if (reduction == Reduction::Mean)
-          {
-              outputTensor = [mpsGraph divisionWithPrimaryTensor: lossTensor
-                                                 secondaryTensor: numelTensor
-                                                            name: nil];
+          if (reduction == Reduction::Mean) {
+            MPSGraphTensor *numelTensor = [mpsGraph constantWithScalar: (double) input.numel()
+                                                              dataType: MPSDataTypeFloat32];
+            outputTensor = [mpsGraph divisionWithPrimaryTensor: lossTensor
+                                               secondaryTensor: numelTensor
+                                                          name: nil];
           }
-
-          MPSGraphTensor *gradInputTensor = outputTensor;
-
+          MPSGraphTensor *gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor: outputTensor
+                                                                      secondaryTensor: gradOutputTensor
+                                                                                 name: nil];
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->targetTensor_ = targetTensor;
           newCachedGraph->gradInputTensor_ = gradInputTensor;
-
+          newCachedGraph->gradOutputTensor_ = gradOutputTensor;
         }
         return newCachedGraph;
       });
-      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
     Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
     Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target);
     Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
+    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
 
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
       inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
-      targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData()
+      targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData(),
+      gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
     };
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
       gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
     };
 
-    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
   }
 }
 
-void smooth_l1_loss_backward_template(
-    const Tensor& grad_output,
-    const Tensor& input,
-    const Tensor& target,
-    int64_t reduction,
-    double beta,
-    Tensor& grad_input)
-{
-  TORCH_CHECK(beta >= 0, "smooth_l1_loss_backward does not support negative values for beta.");
-  TORCH_CHECK(input.is_mps());
-  TORCH_CHECK(target.is_mps());
-
-  smooth_l1_loss_backward_impl(
-      grad_output, input, target, reduction, beta, grad_input
-  );
-}
-
 } // namespace mps
 
 // APIs exposed to at::native scope
@@ -1390,8 +1324,10 @@
     int64_t reduction,
     double beta,
     Tensor& grad_input) {
-  mps::smooth_l1_loss_backward_template(
+
+  mps::smooth_l1_loss_backward_impl(
       grad_output, input, target, reduction, beta, grad_input);
+
   return grad_input;
 }
 
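The smooth_l1_loss_backward() changes above follow the same pattern: feed grad_output into the graph, multiply it into the computed gradient, early-return on empty tensors, and keep the beta >= 0 check. An analogous illustrative check (again a sketch, not part of this PR):

```python
import torch
import torch.nn.functional as F

x_cpu = torch.randn(8, requires_grad=True)
y_cpu = torch.randn(8)
x_mps = x_cpu.detach().clone().to("mps").requires_grad_()
y_mps = y_cpu.to("mps")

loss_cpu = F.smooth_l1_loss(x_cpu, y_cpu, beta=0.5)
loss_mps = F.smooth_l1_loss(x_mps, y_mps, beta=0.5)

# Scale the upstream gradient so dropping grad_output would show up
# as gradients off by a factor of 3.
loss_cpu.backward(torch.tensor(3.0))
loss_mps.backward(torch.tensor(3.0, device="mps"))
torch.testing.assert_close(x_cpu.grad, x_mps.grad.cpu())
```
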
diff --git a/test/test_mps.py b/test/test_mps.py
index e258c3f..e0a0527 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2449,13 +2449,15 @@
         num_channels = input_size[1]
         target_size = (input_size[0], ) + tuple(input_size[2:])
         target = torch.randint(num_channels, target_size, device='cpu')
+        weights = torch.randn(num_channels)
 
         # MPS
         input_mps = input.detach().clone().to('mps').requires_grad_()
         target_mps = target.detach().clone().to('mps')
+        weights_mps = weights.to('mps')
 
-        output_cpu = F.nll_loss(input, target, reduction=reduction)
-        output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
+        output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
+        output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
         self.assertEqual(output_cpu, output_mps.to('cpu'))
 
         output_cpu.sum().backward()
@@ -8369,6 +8371,7 @@
         'nn.functional.max_pool2d': ['f32'],
         'max_pool2d_with_indices_backward': ['f32'],
         'nn.functional.mse_loss': ['f16', 'f32'],
+        'nn.functional.nll_loss': ['f32'],
         'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.padconstant': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.padreflect': ['f32'],
@@ -8584,6 +8587,7 @@
         'nn.functional.max_pool1d': ['f32'],
         'nn.functional.max_pool2d': ['f32'],
         'nn.functional.mse_loss': ['f32'],
+        'nn.functional.nll_loss': ['f32'],
         'nn.functional.pad': ['f16', 'f32', 'i16', 'i32', 'i64'],
         'nn.functional.pairwise_distance': ['f16', 'f32'],
         'nn.functional.poisson_nll_loss': ['f32'],
@@ -8595,6 +8599,7 @@
         'nn.functional.softmin': ['f32'],
         'nn.functional.softplus': ['f32'],
         'nn.functional.softsign': ['f16', 'f32'],
+        'nn.functional.smooth_l1_loss': ['f32'],
         'nn.functional.threshold': ['f32'],
         'nn.functional.triplet_margin_loss': ['f32'],
         'nn.functional.triplet_margin_with_distance_loss': ['f32'],
@@ -8655,7 +8660,6 @@
         'masked.sum': [torch.bool],
 
         # Functions that hard crash
-        'nn.functional.nll_loss': [torch.float32],
         'std': [torch.float16],
         'stft': [torch.float32], 'var': [torch.float16],
         # + forward when requires_grad=True or running backward