[docs] Fix ScalarTensor __repr__ in Extending PyTorch example (#86330)
This PR fixes the __repr__ of the `ScalarTensor` class in the Extending PyTorch example to correspond with the class name instead of `DiagonalTensor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86330
Approved by: https://github.com/bdhirsh
diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst
index 645d687..f7cc413 100644
--- a/docs/source/notes/extending.rst
+++ b/docs/source/notes/extending.rst
@@ -372,7 +372,7 @@
self._value = value
def __repr__(self):
- return "DiagonalTensor(N={}, value={})".format(self._N, self._value)
+ return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
@@ -409,7 +409,7 @@
self._value = value
def __repr__(self):
- return "DiagonalTensor(N={}, value={})".format(self._N, self._value)
+ return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
@@ -494,7 +494,7 @@
>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
- DiagonalTensor(N=2, value=4)
+ ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],