blob: 2b22dca1284cd54147870a47e4f51d504788c2ce [file] [log] [blame]
#include <torch/torch.h>
struct Doubler {
Doubler(int A, int B) {
tensor_ = at::ones({A, B}, torch::CPU(at::kDouble));
torch::set_requires_grad(tensor_, true);
}
at::Tensor forward() {
return tensor_ * 2;
}
at::Tensor get() const {
return tensor_;
}
private:
at::Tensor tensor_;
};