batch norm primspec stub
diff --git a/torch/csrc/autograd/functions/batch_normalization.h b/torch/csrc/autograd/functions/batch_normalization.h
index d12574e..fa07a5e 100644
--- a/torch/csrc/autograd/functions/batch_normalization.h
+++ b/torch/csrc/autograd/functions/batch_normalization.h
@@ -6,6 +6,7 @@
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/autograd/primspec.h"
namespace torch { namespace autograd {
@@ -18,11 +19,12 @@
bool cudnn_enabled;
};
-struct BatchNormForward : public Function, public BatchNormParams {
+struct BatchNormForward : public Function, public BatchNormParams, public HasPrimSpec {
BatchNormForward(BatchNormParams params)
: BatchNormParams(std::move(params)) {}
virtual variable_list apply(const variable_list& inputs) override;
+ HAS_PRIMSPEC;
};
struct BatchNormBackward : public Function, public BatchNormParams {
diff --git a/torch/csrc/autograd/functions/toffee/batch_normalization.cpp b/torch/csrc/autograd/functions/toffee/batch_normalization.cpp
new file mode 100644
index 0000000..f0fd1e5
--- /dev/null
+++ b/torch/csrc/autograd/functions/toffee/batch_normalization.cpp
@@ -0,0 +1,11 @@
+#include "torch/csrc/autograd/functions/batch_normalization.h"
+
+namespace torch {
+namespace autograd {
+
+void BatchNormForward::primspec(PrimSpecContext* ctx, jit::node_list inputs, jit::node_list outputs) {
+ // TODO: implement
+}
+
+} // torch::autograd
+} // torch