| #include <torch/extension.h> | |
| struct Doubler { | |
| Doubler(int A, int B) { | |
| tensor_ = | |
| torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true)); | |
| } | |
| torch::Tensor forward() { | |
| return tensor_ * 2; | |
| } | |
| torch::Tensor get() const { | |
| return tensor_; | |
| } | |
| private: | |
| torch::Tensor tensor_; | |
| }; |