// blob: 3fd57399cc379ad11b03ee295c26e2b651f8c107 [file] [log] [blame]
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/Dispatch.h"
using std::cout;
using namespace at;
// Short aliases for the scalar types this test refers to, so call sites can
// write CPU(Float) instead of CPU(ScalarType::Float).
constexpr auto Float = ScalarType::Float;
// Must name the double-precision type; aliasing ScalarType::Float here would
// make Double silently select single precision.
constexpr auto Double = ScalarType::Double;
// Functor template passed to dispatch<Foo>(...) in main(). It exposes one
// static entry point per backend; presumably dispatch picks CPU or CUDA based
// on the Type it is given and instantiates Foo with the matching scalar type
// (TODO confirm against ATen/Dispatch.h).
template<typename scalar_type>
struct Foo {
  // CPU entry point.
  //   t - the Type being dispatched on; only its toString() is used here.
  //   a - tensor whose raw data pointer is fetched.
  //   b - unused.
  // Prints "hello, dispatch: " followed by t.toString() and the value 1
  // converted to scalar_type.
  static void CPU(const Type & t, Tensor a, Tensor b) {
    scalar_type s = 1;
    cout << "hello, dispatch: " << t.toString() << s << "\n";
    // Named cast instead of a C-style cast; assumes data_ptr() returns an
    // untyped void* (TODO confirm).
    auto data = static_cast<scalar_type*>(a.data_ptr());
    // The pointer is fetched only to exercise data_ptr(); mark it as
    // deliberately unused so the build stays warning-free.
    static_cast<void>(data);
  }

  // CUDA entry point: intentionally a no-op; all parameters are unused.
  static void CUDA(const Type & t, Tensor a, Tensor b) {
  }
};
// Explicit specialization of Foo for Half. Both backend entry points are
// no-ops. NOTE(review): presumably this exists because the generic body
// (initialising from 1 and streaming to cout) is not meant to be used with
// Half -- confirm.
template<>
struct Foo<Half> {
  // No-op on CPU; parameters are unused.
  static void CPU(const Type & t, Tensor a, Tensor b) {
  }

  // No-op on CUDA; parameters are unused.
  static void CUDA(const Type & t, Tensor a, Tensor b) {
  }
};
int main() {
Scalar what = 257;
Scalar bar = 3.0;
Half h = bar.toHalf();
Scalar h2 = h;
cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " " << bar.toDouble() << " " << what.isIntegral() << "\n";
Generator & gen = at::globalContext().defaultGenerator(Backend::CPU);
cout << gen.seed() << "\n";
auto && C = at::globalContext();
auto & CUDAFloat = C.getType(Backend::CPU,ScalarType::Float);
auto t2 = CUDAFloat.zeros({4,4});
cout << &t2 << "\n";
cout << "AFTER GET TYPE " << &CUDAFloat << "\n";
cout << "STORAGE: " << CUDAFloat.storage(4).get() << "\n";
auto s = CUDAFloat.storage(4);
s->fill(7);
cout << "GET " << s->get(3).toFloat() << "\n";
auto t = CPU(Float).ones({4,4});
auto wha2 = CPU(Float).zeros({4,4}).add(t).sum();
cout << wha2.toDouble() << " <-ndim\n";
cout << t.sizes() << " " << t.strides() << "\n";
auto output = CPU(Float).ones(3);
at::Abs_updateOutput(t,output);
Type & T = CPU(Float);
Tensor x = T.randn({1,10});
Tensor prev_h = T.randn({1,20});
Tensor W_h = T.randn({20,20});
Tensor W_x = T.randn({20,10});
Tensor i2h = at::mm(W_x, x.t());
Tensor h2h = at::mm(W_h, prev_h.t());
Tensor next_h = i2h.add(h2h);
next_h = next_h.tanh();
auto r = CUDA(Float).copy(next_h);
cout << r << "\n";
cout << T.randn({10,10,2}) << "\n";
dispatch<Foo>(x.type(),x,prev_h);
return 0;
}