Add C++ API clip_grad_value_ for nn::utils (#28736)

Summary:
Adds the C++ API clip_grad_value_ to the torch::nn::utils module.
Also fixes the indentation error of the for loop in the original test/test_nn.py.

Issue: https://github.com/pytorch/pytorch/issues/25883
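
For reference, a minimal usage sketch of the new API. The Linear model, random data, and mse_loss below are illustrative assumptions, not part of this PR; only the clip_grad_value_ calls come from the API added here:

```cpp
#include <torch/torch.h>

int main() {
  // Illustrative setup: a small Linear model and random data (not from this PR).
  auto model = torch::nn::Linear(10, 1);
  auto input = torch::randn({4, 10});
  auto target = torch::randn({4, 1});

  auto loss = torch::mse_loss(model->forward(input), target);
  loss.backward();

  // New API: clamp every defined gradient into [-clip_value, clip_value], in place.
  auto params = model->parameters();
  torch::nn::utils::clip_grad_value_(params, /*clip_value=*/2.5);

  // The single-Tensor overload added in this PR can be used the same way, e.g.:
  // torch::nn::utils::clip_grad_value_(some_param, 2.5);
  return 0;
}
```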

Reviewer: yf225
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28736

Differential Revision: D18263807

Pulled By: yf225

fbshipit-source-id: 29282450bd2099df16925e1d0edd3d933f6eeb9b
diff --git a/test/cpp/api/nn_utils.cpp b/test/cpp/api/nn_utils.cpp
index fa0693e..f60d958 100644
--- a/test/cpp/api/nn_utils.cpp
+++ b/test/cpp/api/nn_utils.cpp
@@ -44,7 +44,7 @@
 
   std::vector<torch::Tensor> grads = {
       torch::arange(1.0, 101).view({10, 10}),
-      torch::ones(10).div(1000),
+      torch::ones({10}).div(1000),
   };
   std::vector<float> norm_types = {
       0.5,
@@ -101,3 +101,42 @@
     ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
   }
 }
+
+TEST_F(NNUtilsTest, ClipGradValue) {
+  auto linear_layer = Linear(10, 10);
+  float clip_value = 2.5;
+
+  torch::Tensor grad_w = torch::arange(-50., 50).view({10, 10}).div_(5);
+  torch::Tensor grad_b = torch::ones({10}).mul_(2);
+  std::vector<std::vector<torch::Tensor>> grad_lists = {
+      {grad_w, grad_b}, {grad_w, torch::Tensor()}};
+  for (auto grad_list : grad_lists) {
+    for (int i = 0; i < grad_list.size(); i++) {
+      auto p = linear_layer->parameters()[i];
+      auto g = grad_list[i];
+      p.grad() = g.defined() ? g.clone().view_as(p.data()) : g;
+    }
+
+    auto layer_params = linear_layer->parameters();
+    utils::clip_grad_value_(layer_params, clip_value);
+    for (int i = 0; i < layer_params.size(); i++) {
+      if (layer_params[i].grad().defined()) {
+        ASSERT_LE(
+            layer_params[i].grad().data().max().item().toFloat(), clip_value);
+        ASSERT_GE(
+            layer_params[i].grad().data().min().item().toFloat(), -clip_value);
+      }
+    }
+  }
+
+  // Should accept a single Tensor as input
+  auto p1 = torch::randn({10, 10});
+  auto p2 = torch::randn({10, 10});
+  auto g = torch::arange(-50., 50).view({10, 10}).div_(5);
+  p1.grad() = g.clone();
+  p2.grad() = g.clone();
+  utils::clip_grad_value_(p1, clip_value);
+  std::vector<torch::Tensor> params = {p2};
+  utils::clip_grad_value_(params, clip_value);
+  ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
+}
diff --git a/test/test_nn.py b/test/test_nn.py
index 44b01e8..bc3464e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1877,10 +1877,10 @@
             for p, g in zip(l.parameters(), grad_list):
                 p._grad = g.clone().view_as(p.data) if g is not None else g
 
-        clip_grad_value_(l.parameters(), clip_value)
-        for p in filter(lambda p: p.grad is not None, l.parameters()):
-            self.assertLessEqual(p.grad.data.max(), clip_value)
-            self.assertGreaterEqual(p.grad.data.min(), -clip_value)
+            clip_grad_value_(l.parameters(), clip_value)
+            for p in filter(lambda p: p.grad is not None, l.parameters()):
+                self.assertLessEqual(p.grad.data.max(), clip_value)
+                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
 
         # Should accept a single Tensor as input
         p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
diff --git a/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/torch/csrc/api/include/torch/nn/utils/clip_grad.h
index fea0fe0..dcf5668 100644
--- a/torch/csrc/api/include/torch/nn/utils/clip_grad.h
+++ b/torch/csrc/api/include/torch/nn/utils/clip_grad.h
@@ -57,6 +57,28 @@
   return clip_grad_norm_(params, max_norm, norm_type);
 }
 
+// Clips gradient of an iterable of parameters at specified value.
+// Gradients are modified in-place.
+// See https://pytorch.org/docs/stable/nn.html#clip-grad-value
+// for more details about this module.
+inline void clip_grad_value_(
+    std::vector<Tensor>& parameters,
+    float clip_value) {
+
+  for (const auto& param : parameters) {
+    if (param.grad().defined()) {
+      param.grad().data().clamp_(-clip_value, clip_value);
+    }
+  }
+}
+
+// A wrapper around clip_grad_value_ that allows us to call the function with a
+// single Tensor.
+inline void clip_grad_value_(Tensor& parameters, float clip_value) {
+  std::vector<Tensor> params = {parameters};
+  clip_grad_value_(params, clip_value);
+}
+
 } // namespace utils
 } // namespace nn
 } // namespace torch