encode size in name...
diff --git a/torch/csrc/autograd/functions/toffee/batch_normalization.cpp b/torch/csrc/autograd/functions/toffee/batch_normalization.cpp
index d222835..562b661f 100644
--- a/torch/csrc/autograd/functions/toffee/batch_normalization.cpp
+++ b/torch/csrc/autograd/functions/toffee/batch_normalization.cpp
@@ -1,4 +1,5 @@
#include "torch/csrc/autograd/functions/batch_normalization.h"
+#include <sstream>
namespace torch {
namespace autograd {
@@ -27,8 +28,13 @@
ADD_ATTR("order",s,"NCHW");
ADD_ATTR("momentum",f,momentum);
- auto sm = "saved_mean"+std::to_string(ctx->batch_norm_count);
- auto sv = "saved_var"+std::to_string(ctx->batch_norm_count);
+ auto typ = inputs.at(1)->type()->cast<torch::jit::TensorType>();
+ int64_t the_size = typ->sizes()[0];
+ std::stringstream ss;
+ ss << the_size << "_" << ctx->batch_norm_count;
+ std::string suffix = ss.str();
+ auto sm = "saved_mean_"+suffix;
+ auto sv = "saved_var_"+suffix;
ctx->graph->add_input(sm);
ctx->graph->add_input(sv);
p_n->add_input(sm);