#pragma once
#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <torch/torch.h>

#include <vector>

// Reference eager-mode implementation of the deep-and-wide model.
struct DeepAndWide : torch::nn::Module {
  DeepAndWide(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
  }
  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    auto wide_offset = wide + mu_;
    auto wide_normalized = wide_offset * sigma_;
    auto wide_noNaN = wide_normalized;
    // Placeholder for ReplaceNaN
    auto wide_preproc = torch::clamp(wide_noNaN, -10.0, 10.0);

    auto user_emb_t = torch::transpose(user_emb, 1, 2);
    auto dp_unflatten = torch::bmm(ad_emb_packed, user_emb_t);
    auto dp = torch::flatten(dp_unflatten, 1);
    auto input = torch::cat({dp, wide_preproc}, 1);

    auto fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
    auto pred = torch::sigmoid(fc1);
    return pred;
  }

  torch::Tensor mu_, sigma_, fc_w_, fc_b_;
};
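
// Illustrative only (not part of the original benchmark): a minimal sketch of
// how DeepAndWide can be exercised. The input shapes are assumptions inferred
// from forward(): ad_emb_packed and user_emb are {batch, 1, embedding_size}
// (so bmm yields {batch, 1, 1}) and wide is {batch, num_features}.
inline torch::Tensor runDeepAndWideOnce(
    int batch_size = 1,
    int embedding_size = 32,
    int num_features = 50) {
  DeepAndWide model(num_features);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});
  // Produces a {batch_size, 1} tensor of sigmoid outputs.
  return model.forward(ad_emb_packed, user_emb, wide);
}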

// Implementation using native functions and pre-allocated output tensors.
// It can serve as a "speed of light" (upper-bound) baseline for the static
// runtime.
struct DeepAndWideFast : torch::nn::Module {
  DeepAndWideFast(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
    allocated = false;
    prealloc_tensors = {};
  }
  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    torch::NoGradGuard no_grad;
    if (!allocated) {
      // First call: run the ops normally and keep every intermediate tensor so
      // that subsequent calls can reuse the same storage via the _out variants.
      auto wide_offset = at::add(wide, mu_);
      auto wide_normalized = at::mul(wide_offset, sigma_);
      // Placeholder for ReplaceNaN
      auto wide_preproc = at::cpu::clamp(wide_normalized, -10.0, 10.0);

      auto user_emb_t = at::native::transpose(user_emb, 1, 2);
      auto dp_unflatten = at::cpu::bmm(ad_emb_packed, user_emb_t);
      // auto dp = at::native::flatten(dp_unflatten, 1);
      auto dp = dp_unflatten.view({dp_unflatten.size(0), 1});
      auto input = at::cpu::cat({dp, wide_preproc}, 1);

      // fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
      fc_w_t_ = torch::t(fc_w_);
      auto fc1 = torch::addmm(fc_b_, input, fc_w_t_);
      auto pred = at::cpu::sigmoid(fc1);

      prealloc_tensors = {
          wide_offset,
          wide_normalized,
          wide_preproc,
          user_emb_t,
          dp_unflatten,
          dp,
          input,
          fc1,
          pred};
      allocated = true;
      return pred;
    } else {
      // Potential optimization: add and mul could be fused together (e.g. with
      // Eigen); a commented sketch follows.
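      //
      // Hedged sketch of that fusion, illustrative only and not part of the
      // original benchmark. It assumes batch size 1, float dtype, contiguous
      // tensors, and that <Eigen/Core> is available:
      //
      //   Eigen::Map<Eigen::ArrayXf> out(
      //       prealloc_tensors[1].data_ptr<float>(), prealloc_tensors[1].numel());
      //   Eigen::Map<const Eigen::ArrayXf> w(wide.data_ptr<float>(), wide.numel());
      //   Eigen::Map<const Eigen::ArrayXf> m(mu_.data_ptr<float>(), mu_.numel());
      //   Eigen::Map<const Eigen::ArrayXf> s(sigma_.data_ptr<float>(), sigma_.numel());
      //   out = (w + m) * s;  // one pass instead of add_out followed by mul_out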

      // Note on conventions: the at::*_out and at::cpu::*_out overloads take
      // the output tensor first, while *_outf and the at::native:: variants
      // take it last.
      at::add_out(prealloc_tensors[0], wide, mu_);
      at::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_);
      at::native::clip_out(
          prealloc_tensors[1], -10.0, 10.0, prealloc_tensors[2]);
      // Potential optimization: original tensor could be pre-transposed.
      // prealloc_tensors[3] = at::native::transpose(user_emb, 1, 2);
      if (prealloc_tensors[3].data_ptr() != user_emb.data_ptr()) {
        auto sizes = user_emb.sizes();
        auto strides = user_emb.strides();
        prealloc_tensors[3].set_(
            user_emb.storage(),
            0,
            {sizes[0], sizes[2], sizes[1]},
            {strides[0], strides[2], strides[1]});
      }
      // Potential optimization: call MKLDNN directly.
      at::cpu::bmm_out(prealloc_tensors[4], ad_emb_packed, prealloc_tensors[3]);
      if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) {
        // In the unlikely case that the input tensor changed, we need to
        // reinitialize the view.
        prealloc_tensors[5] =
            prealloc_tensors[4].view({prealloc_tensors[4].size(0), 1});
      }
      // Potential optimization: we could replace cat with carefully constructed
      // tensor views on the output that are passed to the _out ops above; a
      // commented sketch follows.
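      //
      // Hedged sketch of that idea, illustrative only and not part of the
      // original benchmark: carve the concat output into slices up front and
      // hand those slices to the earlier _out calls, so the copy performed by
      // cat disappears.
      //
      //   auto dp_slice = prealloc_tensors[6].narrow(1, 0, 1);
      //   auto wide_slice = prealloc_tensors[6].narrow(1, 1, wide.size(1));
      //   // ... then use dp_slice / wide_slice as the outputs that currently
      //   // land in prealloc_tensors[5] and prealloc_tensors[2].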
      at::cpu::cat_outf(
          {prealloc_tensors[5], prealloc_tensors[2]}, 1, prealloc_tensors[6]);
      at::cpu::addmm_out(
          prealloc_tensors[7], fc_b_, prealloc_tensors[6], fc_w_t_, 1, 1);
      at::cpu::sigmoid_out(prealloc_tensors[8], prealloc_tensors[7]);
      return prealloc_tensors[8];
    }
  }

  torch::Tensor mu_, sigma_, fc_w_, fc_b_, fc_w_t_;
  std::vector<torch::Tensor> prealloc_tensors;
  bool allocated = false;
};
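
// Illustrative only (not part of the original benchmark): a minimal sketch of
// the two-phase behavior of DeepAndWideFast. The first call allocates the
// intermediate tensors; later calls write into the same storage through the
// _out ops, so the returned tensor aliases the same buffer across iterations.
// Input shapes are the same assumptions as in runDeepAndWideOnce above.
inline void runDeepAndWideFastTwice() {
  DeepAndWideFast model;
  auto ad_emb_packed = torch::randn({1, 1, 32});
  auto user_emb = torch::randn({1, 1, 32});
  auto wide = torch::randn({1, 50});
  auto out1 = model.forward(ad_emb_packed, user_emb, wide); // allocates buffers
  auto out2 = model.forward(ad_emb_packed, user_emb, wide); // reuses buffers
  TORCH_CHECK(out1.data_ptr() == out2.data_ptr());
}
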
torch::jit::Module getDeepAndWideSciptModel(int num_features = 50);
torch::jit::Module getTrivialScriptModel();
torch::jit::Module getLeakyReLUScriptModel();
torch::jit::Module getLeakyReLUConstScriptModel();
torch::jit::Module getLongScriptModel();
torch::jit::Module getSignedLog1pModel();
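
// The factories above are only declared here; their definitions live
// elsewhere. Purely as a hypothetical illustration (the real definitions may
// differ), such a factory could be built by compiling TorchScript source at
// runtime, e.g.:
//
//   torch::jit::Module getTrivialScriptModel() {
//     torch::jit::Module m("m");
//     m.define(R"JIT(
//       def forward(self, a, b):
//           return a + b
//     )JIT");
//     return m;
//   }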