Remove return value for __exit__ (#32997)

Summary:
When an error is raised inside a `with` block and the context manager's `__exit__` returns a truthy value, the error is suppressed; otherwise it propagates. `__exit__` should therefore return nothing here, so that the default behavior of the context manager (letting the error propagate) is preserved.
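
For illustration only (not part of this patch), a minimal sketch of the `__exit__` semantics: any truthy return value, such as `self` or `True`, suppresses the active exception, while returning nothing (`None`) lets it propagate.

```python
class _Suppressing:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning a truthy object (e.g. `self` or True) tells Python the
        # exception has been handled, so it is silently swallowed.
        return self


class _Propagating:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Implicitly returns None (falsy), so the exception propagates.
        pass


with _Suppressing():
    raise TypeError("swallowed")  # silently suppressed; execution continues

try:
    with _Propagating():
        raise TypeError("raised")
except TypeError:
    print("TypeError propagated as expected")
```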

Fixes https://github.com/pytorch/pytorch/issues/32639. In that issue, `get_lr` was overridden with a signature taking an extra epoch/step parameter, which is not allowed; however, because `__exit__` returned a truthy value, the resulting `TypeError` was being silently suppressed instead of raised.

```python
In [1]: import torch
   ...:
   ...: class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
   ...:     def __init__(self, optimizer, gamma, milestones, last_epoch = -1):
   ...:         self.init_lr = [group['lr'] for group in optimizer.param_groups]
   ...:         self.gamma = gamma
   ...:         self.milestones = milestones
   ...:         super().__init__(optimizer, last_epoch)
   ...:
   ...:     def get_lr(self, step):
   ...:         global_step = self.last_epoch #iteration number in pytorch
   ...:         gamma_power = ([0] + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m])[-1]
   ...:         return [init_lr * (self.gamma ** gamma_power) for init_lr in self.init_lr]
   ...:
   ...: optimizer = torch.optim.SGD([torch.rand(1)], lr = 1)
   ...: scheduler = MultiStepLR(optimizer, gamma = 1, milestones = [10, 20])
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-7fad6ba050b0> in <module>
     14
     15 optimizer = torch.optim.SGD([torch.rand(1)], lr = 1)
---> 16 scheduler = MultiStepLR(optimizer, gamma = 1, milestones = [10, 20])

<ipython-input-1-7fad6ba050b0> in __init__(self, optimizer, gamma, milestones, last_epoch)
      6         self.gamma = gamma
      7         self.milestones = milestones
----> 8         super().__init__(optimizer, last_epoch)
      9
     10     def get_lr(self, step):

~/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py in __init__(self, optimizer, last_epoch)
     75         self._step_count = 0
     76
---> 77         self.step()
     78
     79     def state_dict(self):

~/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/optim/lr_scheduler.py in step(self, epoch)
    141                 print("1a")
    142                 # try:
--> 143                 values = self.get_lr()
    144                 # except TypeError:
    145                     # raise RuntimeError

TypeError: get_lr() missing 1 required positional argument: 'step'
```
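
For reference, a hedged sketch of the user-side fix (not part of this patch): with the suppression removed, the `TypeError` above surfaces, and the override should instead match the base class signature, reading the step count from `self.last_epoch`.

```python
import torch

class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
        self.init_lr = [group['lr'] for group in optimizer.param_groups]
        self.gamma = gamma
        self.milestones = milestones
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # No extra positional argument: the iteration count is read from
        # self.last_epoch, which _LRScheduler maintains.
        global_step = self.last_epoch
        gamma_power = ([0] + [i + 1 for i, m in enumerate(self.milestones)
                              if global_step >= m])[-1]
        return [init_lr * (self.gamma ** gamma_power) for init_lr in self.init_lr]

optimizer = torch.optim.SGD([torch.rand(1)], lr=1)
scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20])  # no TypeError
```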

May be related to https://github.com/pytorch/pytorch/issues/32898.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32997

Differential Revision: D19737731

Pulled By: vincentqb

fbshipit-source-id: 5cf84beada69b91f91e36b20c3278e9920343655
diff --git a/test/test_optim.py b/test/test_optim.py
index 17a0b6d..45de0d1 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -509,6 +509,24 @@
             [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
             lr=0.05)
 
+    def test_error_when_getlr_has_epoch(self):
+        class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
+            def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
+                self.init_lr = [group['lr'] for group in optimizer.param_groups]
+                self.gamma = gamma
+                self.milestones = milestones
+                super().__init__(optimizer, last_epoch)
+
+            def get_lr(self, step):
+                global_step = self.last_epoch
+                gamma_power = ([0] + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m])[-1]
+                return [init_lr * (self.gamma ** gamma_power) for init_lr in self.init_lr]
+
+        optimizer = torch.optim.SGD([torch.rand(1)], lr=1)
+
+        with self.assertRaises(TypeError):
+            scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20])
+
     def test_no_cyclic_references(self):
         import gc
         param = Variable(torch.Tensor(10), requires_grad=True)
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 3dffa39..461cee4 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -134,7 +134,6 @@
 
             def __exit__(self, type, value, traceback):
                 self.o._get_lr_called_within_step = False
-                return self
 
         with _enable_get_lr_call(self):
             if epoch is None: