| #include <torch/torch.h> | |
| struct Doubler { | |
| Doubler(int A, int B) { | |
| tensor_ = at::ones({A, B}, torch::CPU(at::kDouble)); | |
| torch::set_requires_grad(tensor_, true); | |
| } | |
| at::Tensor forward() { | |
| return tensor_ * 2; | |
| } | |
| at::Tensor get() const { | |
| return tensor_; | |
| } | |
| private: | |
| at::Tensor tensor_; | |
| }; |