blob: d9e6aaea8c34651ceaa5d4aee1e3a21708ac8849 [file] [log] [blame]
#include <torch/extension.h>
struct Doubler {
Doubler(int A, int B) {
tensor_ = at::ones({A, B}, torch::CPU(at::kDouble));
torch::set_requires_grad(tensor_, true);
}
at::Tensor forward() {
return tensor_ * 2;
}
at::Tensor get() const {
return tensor_;
}
private:
at::Tensor tensor_;
};