| #pragma once |
| |
| #include <ATen/CPUFunctions.h> |
| #include <ATen/NativeFunctions.h> |
| #include <torch/torch.h> |
| |
| struct DeepAndWide : torch::nn::Module { |
| DeepAndWide(int num_features = 50) { |
| mu_ = register_parameter("mu_", torch::randn({1, num_features})); |
| sigma_ = register_parameter("sigma_", torch::randn({1, num_features})); |
| fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1})); |
| fc_b_ = register_parameter("fc_b_", torch::randn({1})); |
| } |
| |
| torch::Tensor forward( |
| torch::Tensor ad_emb_packed, |
| torch::Tensor user_emb, |
| torch::Tensor wide) { |
| auto wide_offset = wide + mu_; |
| auto wide_normalized = wide_offset * sigma_; |
| auto wide_noNaN = wide_normalized; |
| // Placeholder for ReplaceNaN |
| auto wide_preproc = torch::clamp(wide_noNaN, -10.0, 10.0); |
| |
| auto user_emb_t = torch::transpose(user_emb, 1, 2); |
| auto dp_unflatten = torch::bmm(ad_emb_packed, user_emb_t); |
| auto dp = torch::flatten(dp_unflatten, 1); |
| auto input = torch::cat({dp, wide_preproc}, 1); |
| auto fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_); |
| auto pred = torch::sigmoid(fc1); |
| return pred; |
| } |
| torch::Tensor mu_, sigma_, fc_w_, fc_b_; |
| }; |
| |
| // Implementation using native functions and pre-allocated tensors. |
| // It could be used as a "speed of light" for static runtime. |
struct DeepAndWideFast : torch::nn::Module {
  // Hand-optimized variant of DeepAndWide: the first forward() call runs the
  // model once and stashes every intermediate tensor; later calls reuse those
  // buffers through _out op variants, so steady-state inference allocates
  // nothing new.
  DeepAndWideFast(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
    allocated = false;
    prealloc_tensors = {};
  }

  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    torch::NoGradGuard no_grad; // inference only; skip autograd bookkeeping
    if (!allocated) {
      // First call: compute everything normally, keeping each intermediate so
      // the steady-state branch below can write into it in place.
      auto wide_offset = at::add(wide, mu_);
      auto wide_normalized = at::mul(wide_offset, sigma_);
      // Placeholder for ReplaceNaN
      auto wide_preproc = at::cpu::clamp(wide_normalized, -10.0, 10.0);

      auto user_emb_t = at::native::transpose(user_emb, 1, 2);
      auto dp_unflatten = at::cpu::bmm(ad_emb_packed, user_emb_t);
      // auto dp = at::native::flatten(dp_unflatten, 1);
      // NOTE(review): the view to {B, 1} assumes the bmm result is [B, 1, 1]
      // (single-row ad and user embeddings) — confirm against callers.
      auto dp = dp_unflatten.view({dp_unflatten.size(0), 1});
      auto input = at::cpu::cat({dp, wide_preproc}, 1);

      // fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
      // Pre-transpose the weight once so later calls can use addmm directly.
      fc_w_t_ = torch::t(fc_w_);
      auto fc1 = torch::addmm(fc_b_, input, fc_w_t_);

      auto pred = at::cpu::sigmoid(fc1);

      // Cache all intermediates; the steady-state branch refers to them by
      // these positions (0..8 in dataflow order).
      prealloc_tensors = {
          wide_offset,
          wide_normalized,
          wide_preproc,
          user_emb_t,
          dp_unflatten,
          dp,
          input,
          fc1,
          pred};
      allocated = true;

      return pred;
    } else {
      // Steady state: recompute the whole pipeline into pre-allocated buffers.
      // NOTE(review): the out-argument positions below must match the
      // generated at::cpu::/at::native:: signatures of the ATen version this
      // was written against — verify against the current build.
      // Potential optimization: add and mul could be fused together (e.g. with
      // Eigen).
      at::add_out(prealloc_tensors[0], wide, mu_);
      at::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_);

      at::native::clip_out(
          prealloc_tensors[1], -10.0, 10.0, prealloc_tensors[2]);

      // Potential optimization: original tensor could be pre-transposed.
      // prealloc_tensors[3] = at::native::transpose(user_emb, 1, 2);
      // The cached transposed view shares user_emb's data_ptr, so pointer
      // equality means it is still valid; otherwise rebind it onto the new
      // input's storage with dims 1 and 2 swapped (transpose without
      // allocating a fresh TensorImpl).
      if (prealloc_tensors[3].data_ptr() != user_emb.data_ptr()) {
        auto sizes = user_emb.sizes();
        auto strides = user_emb.strides();
        prealloc_tensors[3].set_(
            user_emb.storage(),
            0,
            {sizes[0], sizes[2], sizes[1]},
            {strides[0], strides[2], strides[1]});
      }

      // Potential optimization: call MKLDNN directly.
      at::cpu::bmm_out(ad_emb_packed, prealloc_tensors[3], prealloc_tensors[4]);

      if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) {
        // in unlikely case that the input tensor changed we need to
        // reinitialize the view
        prealloc_tensors[5] =
            prealloc_tensors[4].view({prealloc_tensors[4].size(0), 1});
      }

      // Potential optimization: we can replace cat with carefully constructed
      // tensor views on the output that are passed to the _out ops above.
      at::cpu::cat_outf(
          {prealloc_tensors[5], prealloc_tensors[2]}, 1, prealloc_tensors[6]);
      at::cpu::addmm_out(
          prealloc_tensors[7], fc_b_, prealloc_tensors[6], fc_w_t_, 1, 1);
      at::cpu::sigmoid_out(prealloc_tensors[7], prealloc_tensors[8]);

      return prealloc_tensors[8];
    }
  }
  // Model parameters plus the cached transpose of fc_w_ (set on first call).
  torch::Tensor mu_, sigma_, fc_w_, fc_b_, fc_w_t_;
  // Intermediates captured by the first forward(); see index comments above.
  std::vector<torch::Tensor> prealloc_tensors;
  bool allocated = false;
};
| |
// Factory helpers returning TorchScript test models; definitions live in the
// corresponding .cpp file.

// NOTE(review): "Scipt" is presumably a typo for "Script"; renaming would be
// an interface change for existing callers, so the identifier is kept as-is.
torch::jit::Module getDeepAndWideSciptModel(int num_features = 50);

torch::jit::Module getTrivialScriptModel();

torch::jit::Module getLeakyReLUScriptModel();

torch::jit::Module getLeakyReLUConstScriptModel();

torch::jit::Module getLongScriptModel();

torch::jit::Module getSignedLog1pModel();