codemod primspec -> symbolic, PrimSpec -> Symbolic
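
Mechanical rename across the ONNX export path: the Python autograd functions'
primspec staticmethods become symbolic, the C++ HasPrimSpec / PrimSpecContext
types become HasSymbolic / SymbolicContext, and the files
torch/csrc/autograd/primspec.h and torch/nn/_functions/thnn/auto_primspec.py
are renamed to symbolic.h and auto_symbolic.py. No behavior changes; only the
naming of the export hook.

As a minimal sketch of the convention after this change (MyRelu below is an
illustrative, hypothetical Function, not part of this diff):

    class MyRelu(Function):

        @staticmethod
        def symbolic(g, input):
            # Emit the corresponding ONNX node into the export graph g.
            return g.op("Relu", input)
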
diff --git a/torch/autograd/_functions/basic_ops.py b/torch/autograd/_functions/basic_ops.py
index 5217172..73d7e22 100644
--- a/torch/autograd/_functions/basic_ops.py
+++ b/torch/autograd/_functions/basic_ops.py
@@ -7,7 +7,7 @@
 class Add(InplaceFunction):
 
     @staticmethod
-    def primspec(g, a, b, inplace=False):
+    def symbolic(g, a, b, inplace=False):
         # TODO: [Export inplace]
         return g.appendNode(g.create("Add", [a, b]))
 
@@ -29,7 +29,7 @@
 class Sub(InplaceFunction):
 
     @staticmethod
-    def primspec(g, a, b, inplace=False):
+    def symbolic(g, a, b, inplace=False):
         # TODO: [Export inplace]
         return g.appendNode(g.create("Sub", [a, b]))
 
@@ -51,7 +51,7 @@
 class Mul(Function):
 
     @staticmethod
-    def primspec(g, a, b, inplace=False):
+    def symbolic(g, a, b, inplace=False):
         # TODO: [Export inplace]
         return g.op("Mul", a, b)
 
@@ -231,7 +231,7 @@
 class Negate(InplaceFunction):
 
     @staticmethod
-    def primspec(g, i, inplace=False):
+    def symbolic(g, i, inplace=False):
         # TODO: [Export inplace]
         return g.appendNode(g.create("Scale", [i]).f_("scale", -1))
 
diff --git a/torch/autograd/_functions/blas.py b/torch/autograd/_functions/blas.py
index 85c7acf..c7bcb3b2 100644
--- a/torch/autograd/_functions/blas.py
+++ b/torch/autograd/_functions/blas.py
@@ -16,7 +16,7 @@
 class Addmm(InplaceFunction):
 
     @staticmethod
-    def primspec(g, add_matrix, matrix1, matrix2, alpha=1, beta=1, inplace=False):
+    def symbolic(g, add_matrix, matrix1, matrix2, alpha=1, beta=1, inplace=False):
         # TODO: manually insert the necessary scaling, since ONNX doesn't
         # natively support it
         if alpha != 1 or beta != 1:
diff --git a/torch/autograd/_functions/pointwise.py b/torch/autograd/_functions/pointwise.py
index 726a8fb..a2a4276 100644
--- a/torch/autograd/_functions/pointwise.py
+++ b/torch/autograd/_functions/pointwise.py
@@ -55,7 +55,7 @@
 class Tanh(InplaceFunction):
 
     @staticmethod
-    def primspec(g, i, inplace=False):
+    def symbolic(g, i, inplace=False):
         # TODO: [Export inplace]
         return g.op("Tanh", i)
 
@@ -85,7 +85,7 @@
 class Sigmoid(InplaceFunction):
 
     @staticmethod
-    def primspec(g, i, inplace=False):
+    def symbolic(g, i, inplace=False):
         return g.op("Sigmoid", i)
 
     @staticmethod
diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py
index b6bed68..721b819 100644
--- a/torch/autograd/_functions/tensor.py
+++ b/torch/autograd/_functions/tensor.py
@@ -10,7 +10,7 @@
 class Index(Function):
 
     @staticmethod
-    def primspec(g, i, index):
+    def symbolic(g, i, index):
         # We should only expect index as an integer in this case.
         # We use "Slice" to get the index-th element in i,
         # Then we reduce the dimension using "Reshape".
@@ -95,7 +95,7 @@
 class Transpose(Function):
 
     @staticmethod
-    def primspec(g, i, dim1, dim2):
+    def symbolic(g, i, dim1, dim2):
         # NB: Swap dim1 and dim2, which is different from ONNX's
         # Transpose, which is actually a permute.
         if dim1 == dim2:
@@ -120,7 +120,7 @@
 class View(Function):
 
     @staticmethod
-    def primspec(g, i, sizes):
+    def symbolic(g, i, sizes):
         return g.op("Reshape", i, shape_i=sizes)
 
     @staticmethod
@@ -197,7 +197,7 @@
 class Permute(Function):
 
     @staticmethod
-    def primspec(g, input, dim_indices):
+    def symbolic(g, input, dim_indices):
         if dim_indices == list(range(0, len(dim_indices))):
             return input
         return g.op("Transpose", input, perm_i=dim_indices)
@@ -351,7 +351,7 @@
 class Concat(Function):
 
     @staticmethod
-    def primspec(g, dim, *inputs):
+    def symbolic(g, dim, *inputs):
         n, _ = g.op("Concat", *inputs, axis_i=dim, outputs=2)
         return n
 
@@ -409,7 +409,7 @@
 class Squeeze(InplaceFunction):
 
     @staticmethod
-    def primspec(g, input, dim, inplace=False):
+    def symbolic(g, input, dim, inplace=False):
         # TODO: [Export inplace]
         if dim is None:
             dims = []
@@ -594,7 +594,7 @@
 class Chunk(Function):
 
     @staticmethod
-    def primspec(g, i, num_chunks, dim=0):
+    def symbolic(g, i, num_chunks, dim=0):
         dim_size = i.type().sizes()[dim]
         split_size = (dim_size + num_chunks - 1) // num_chunks
         lengths = []
diff --git a/torch/csrc/autograd/functions/batch_normalization.h b/torch/csrc/autograd/functions/batch_normalization.h
index ad581f3..b565738 100644
--- a/torch/csrc/autograd/functions/batch_normalization.h
+++ b/torch/csrc/autograd/functions/batch_normalization.h
@@ -6,7 +6,7 @@
 
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/primspec.h"
+#include "torch/csrc/autograd/symbolic.h"
 
 namespace torch { namespace autograd {
 
@@ -19,12 +19,12 @@
   bool cudnn_enabled;
 };
 
-struct BatchNormForward : public ForwardFunction<>, public BatchNormParams, public HasPrimSpec {
+struct BatchNormForward : public ForwardFunction<>, public BatchNormParams, public HasSymbolic {
   BatchNormForward(BatchNormParams params)
     : BatchNormParams(std::move(params)) {}
 
   virtual variable_list apply(const variable_list& inputs) override;
-  virtual jit::node_list primspec(PrimSpecContext* ctx, jit::node_list inputs) override;
+  virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
 };
 
 struct BatchNormBackward : public Function, public BatchNormParams {
diff --git a/torch/csrc/autograd/functions/convolution.h b/torch/csrc/autograd/functions/convolution.h
index 7317c0d..a328a62 100644
--- a/torch/csrc/autograd/functions/convolution.h
+++ b/torch/csrc/autograd/functions/convolution.h
@@ -8,7 +8,7 @@
 
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/primspec.h"
+#include "torch/csrc/autograd/symbolic.h"
 
 #ifdef WITH_CUDNN
 #include "torch/csrc/cudnn/Conv.h"
@@ -38,12 +38,12 @@
   bool use_cudnn(const at::Tensor& input) const;
 };
 
-struct ConvForward : public ForwardFunction<>, public ConvParams, public HasPrimSpec {
+struct ConvForward : public ForwardFunction<>, public ConvParams, public HasSymbolic {
   explicit ConvForward(ConvParams params) : ConvParams(std::move(params)) {}
 
   virtual std::string name() override;
   virtual variable_list apply(const variable_list& inputs) override;
-  virtual jit::node_list primspec(PrimSpecContext* ctx, jit::node_list inputs) override;
+  virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
 
   std::vector<int64_t> output_size(at::Tensor& input, at::Tensor& weight);
 };
diff --git a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp b/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
index c071a30..e51e858 100644
--- a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
+++ b/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
@@ -6,7 +6,7 @@
 namespace torch {
 namespace autograd {
 
-jit::node_list BatchNormForward::primspec(PrimSpecContext* ctx, jit::node_list inputs) {
+jit::node_list BatchNormForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) {
   auto & g = ctx->graph;
   // X, Scale, Bias
   auto bn = g->appendNode(g->create(jit::kSpatialBN,{inputs.at(0),inputs.at(1),inputs.at(2)}));
diff --git a/torch/csrc/autograd/functions/onnx/convolution.cpp b/torch/csrc/autograd/functions/onnx/convolution.cpp
index 059869b..1b4f756 100644
--- a/torch/csrc/autograd/functions/onnx/convolution.cpp
+++ b/torch/csrc/autograd/functions/onnx/convolution.cpp
@@ -25,7 +25,7 @@
 // passed here; it's done as an external addition.  This is less efficient
 // but this code should be temporary anyway.
 
-jit::node_list ConvForward::primspec(PrimSpecContext* ctx, jit::node_list inputs) {
+jit::node_list ConvForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) {
   auto & g = ctx->graph;
   // See Note [Caffe2ConvTranspose]
   auto n = g->create(!transposed ? jit::kConv : jit::kCaffe2ConvTranspose,
diff --git a/torch/csrc/autograd/primspec.h b/torch/csrc/autograd/symbolic.h
similarity index 68%
rename from torch/csrc/autograd/primspec.h
rename to torch/csrc/autograd/symbolic.h
index 381c462..9ce470e 100644
--- a/torch/csrc/autograd/primspec.h
+++ b/torch/csrc/autograd/symbolic.h
@@ -7,22 +7,22 @@
 
 namespace torch { namespace autograd {
 
-struct PrimSpecContext {
+struct SymbolicContext {
   jit::Graph* graph;
   const std::unordered_map<void*, jit::Node*>* buffer_map;
   int batch_norm_count = 0;
 };
 
-struct primspec_unconvertible : public std::runtime_error {
+struct symbolic_unconvertible : public std::runtime_error {
   using std::runtime_error::runtime_error;
 };
 
 
-struct HasPrimSpec {
+struct HasSymbolic {
   // Add some nodes to the ONNX protobuf, under the assumption that this node
   // as a whole has the represented inputs and outputs.  Raises a
-  // primspec_unconvertible exception if conversion is not supported.
-  virtual jit::node_list primspec(PrimSpecContext* ctx, jit::node_list inputs) = 0;
+  // symbolic_unconvertible exception if conversion is not supported.
+  virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) = 0;
 };
 
 }} // namespace torch::autograd
diff --git a/torch/csrc/onnx/export.cpp b/torch/csrc/onnx/export.cpp
index e63bd76..fb0978d 100644
--- a/torch/csrc/onnx/export.cpp
+++ b/torch/csrc/onnx/export.cpp
@@ -1,5 +1,5 @@
 #include "torch/csrc/onnx/export.h"
-#include "torch/csrc/autograd/primspec.h"
+#include "torch/csrc/autograd/symbolic.h"
 #include "torch/csrc/utils/python_numbers.h"
 #include "torch/csrc/utils/python_strings.h"
 #include "torch/csrc/Exceptions.h"
@@ -29,7 +29,7 @@
 std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& g,
                                   const std::unordered_map<void*, Node*>& old_buffer_map,
                                   bool verbose) {
-  torch::autograd::PrimSpecContext ctx;
+  torch::autograd::SymbolicContext ctx;
   std::unordered_map<Node*, Node*> env;
   std::shared_ptr<Graph> out_graph = std::make_shared<Graph>();
   ctx.graph = out_graph.get();
@@ -51,24 +51,24 @@
   ctx.buffer_map = &buffer_map;
   // put the new outputs in our environment map, and
   // copy the type from the input graph if they were not set by the
-  // primspec
+  // symbolic
   auto setOutputs = [&](Node * node, const node_list & outputs) {
     auto old_outputs = node->outputs();
-    // The primspec can produce less outputs than the actual IR node,
+    // The symbolic can produce less outputs than the actual IR node,
     // because many IR nodes have an implicit extra trailing output
     // of type Handle, which is irrelevant for the purposes of export.
-    // It's bad design to ask the primspec() implementers to actually
+    // It's bad design to ask the symbolic() implementers to actually
     // handle this!
-    JIT_ASSERTM(outputs.size() <= old_outputs.size(), "primspec produced too many outputs");
+    JIT_ASSERTM(outputs.size() <= old_outputs.size(), "symbolic produced too many outputs");
     size_t i = 0;
     for(auto & old : old_outputs) {
       // NB: There is at most one handle, and if it exists, it is the last input
       if(i >= outputs.size()) {
-        // primspecs do not deal with Handles at the moment, so we just
+        // symbolics do not deal with Handles at the moment, so we just
         // assert the handle isn't actually used.
         auto typ = old->typeOption();
         JIT_ASSERTM(typ && typ->kind() == jit::TypeKind::HandleType,
-          "primspec produced too few outputs");
+          "symbolic produced too few outputs");
         env[old] = nullptr;
         if (!old->uses().empty()) {
           throw std::runtime_error("In ONNX export, handles should be unused");
@@ -97,25 +97,25 @@
       // Selects are translated by multi-return nodes.
       JIT_ASSERT(env.count(value) > 0);
     IR_ELSEIFM(CppOp)
-      if (auto fn = std::dynamic_pointer_cast<autograd::HasPrimSpec>(value->fn)) {
-        auto outputs = fn->primspec(&ctx, fmap(node->inputs(), envFn));
+      if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
+        auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn));
         setOutputs(node, outputs);
       } else {
-        throw std::runtime_error("CppOp doesn't define primspec " + value->name());
+        throw std::runtime_error("CppOp doesn't define symbolic " + value->name());
       }
     IR_ELSEIFM(PythonOp)
       auto pyobj = py::handle(value->pyobj.get());
-      if(!py::hasattr(pyobj, "primspec"))
-        throw std::runtime_error("PythonOp doesn't define primspec " + value->name());
+      if(!py::hasattr(pyobj, "symbolic"))
+        throw std::runtime_error("PythonOp doesn't define symbolic " + value->name());
 
-      py::object primspec_fn = pyobj.attr("primspec");
+      py::object symbolic_fn = pyobj.attr("symbolic");
 
-      py::tuple py_primspec_args(1+value->cconv.size());
+      py::tuple py_symbolic_args(1+value->cconv.size());
 
       auto node_it = node->inputs().begin();
       auto scalar_it = value->scalar_args.begin();
       Py_ssize_t input_nr = 0;
-      py_primspec_args[input_nr++] = py::cast(ctx.graph);
+      py_symbolic_args[input_nr++] = py::cast(ctx.graph);
 
       for (auto arg_type : value->cconv) {
         py::object obj;
@@ -132,13 +132,13 @@
         } else {
           throw std::runtime_error("unexpected calling convention");
         }
-        py_primspec_args[input_nr++] = obj;
+        py_symbolic_args[input_nr++] = obj;
       }
-      py::object raw_output = py::reinterpret_steal<py::object>(PyObject_CallObject(primspec_fn.ptr(), py_primspec_args.ptr()));
+      py::object raw_output = py::reinterpret_steal<py::object>(PyObject_CallObject(symbolic_fn.ptr(), py_symbolic_args.ptr()));
       if(!raw_output)
         throw py::error_already_set();
       if(raw_output.ptr() == Py_None)
-        throw std::runtime_error("PythonOp's primspec returned None, indicating conversion not supported " + value->name());
+        throw std::runtime_error("PythonOp's symbolic returned None, indicating conversion not supported " + value->name());
       node_list outputs;
       if(py::isinstance<Node>(raw_output)) {
         outputs.push_back(py::cast<Node*>(raw_output));
@@ -282,7 +282,7 @@
     }
     if (node->kind() == kUndefined && node->uses().empty()) {
       // Undefined nodes never show up in ONNX; they're just a tool
-      // to help primspecs do the right thing.
+      // to help symbolics do the right thing.
       continue;
     }
     auto p_n = p_g->add_node();
diff --git a/torch/nn/_functions/dropout.py b/torch/nn/_functions/dropout.py
index ff2e4cf..3937cef 100644
--- a/torch/nn/_functions/dropout.py
+++ b/torch/nn/_functions/dropout.py
@@ -11,7 +11,7 @@
         return input.new().resize_as_(input)
 
     @staticmethod
-    def primspec(g, input, p=0.5, train=False, inplace=False):
+    def symbolic(g, input, p=0.5, train=False, inplace=False):
         if inplace:
             return None
         n = g.appendNode(g.create("Dropout", [input])
@@ -58,7 +58,7 @@
 class FeatureDropout(Dropout):
 
     @staticmethod
-    def primspec(input, p=0.5, train=False, inplace=False):
+    def symbolic(input, p=0.5, train=False, inplace=False):
         return None
 
     @staticmethod
diff --git a/torch/nn/_functions/thnn/activation.py b/torch/nn/_functions/thnn/activation.py
index 8dc086b..1beb36d 100644
--- a/torch/nn/_functions/thnn/activation.py
+++ b/torch/nn/_functions/thnn/activation.py
@@ -10,7 +10,7 @@
 class PReLU(Function):
 
     @staticmethod
-    def primspec(g, input, weight):
+    def symbolic(g, input, weight):
         # TODO: Properly support numel in type()
         if all(s == 1 for s in weight.type().sizes()):
             raise RuntimeError("single weight shared among input channels not supported")
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index eb72a45..21fbd9d 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -7,13 +7,13 @@
 from torch.autograd.function import Function, InplaceFunction, once_differentiable
 from torch._thnn import type2backend
 from .auto_double_backwards import double_backwards_fns
-from .auto_primspec import primspec_fns
+from .auto_symbolic import symbolic_fns
 
 from . import _all_functions
 
 
 def _make_function_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters,
-                                   double_backwards_fn, primspec_fn):
+                                   double_backwards_fn, symbolic_fn):
     weight_arg_idx = -1
     for i, arg in enumerate(update_output.arguments):
         if arg.name.startswith('weight'):
@@ -28,8 +28,8 @@
         additional_arg_idx += 1
 
     @staticmethod
-    def primspec(*args, **kwargs):
-        a = primspec_fn(*args, **kwargs)
+    def symbolic(*args, **kwargs):
+        a = symbolic_fn(*args, **kwargs)
         return a
 
     @staticmethod
@@ -78,7 +78,7 @@
 
     backward_cls = type(class_name + "Backward", (Function,),
                         dict(forward=backward_cls_forward, backward=backward_cls_backward))
-    return type(class_name, (Function,), dict(forward=forward, backward=backward, primspec=primspec)), backward_cls
+    return type(class_name, (Function,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
 
 
 def _find_buffers(args, ignored_args):
@@ -94,7 +94,7 @@
 
 
 def _make_function_class(class_name, update_output, update_grad_input, acc_grad_parameters,
-                         double_backwards_fn, primspec_fn):
+                         double_backwards_fn, symbolic_fn):
     def has_argument(fn, name):
         for arg in fn.arguments:
             if arg.name == name:
@@ -129,8 +129,8 @@
         return tuple(additional_args)
 
     @staticmethod
-    def primspec(*args, **kwargs):
-        return primspec_fn(*args, **kwargs)
+    def symbolic(*args, **kwargs):
+        return symbolic_fn(*args, **kwargs)
 
     @staticmethod
     def forward(ctx, input, *params):
@@ -251,7 +251,7 @@
     backward_cls = type(class_name + "Backward", (base_class,), dict(forward=backward_cls_forward,
                                                                      backward=backward_cls_backward))
 
-    return type(class_name, (base_class,), dict(forward=forward, backward=backward, primspec=primspec)), backward_cls
+    return type(class_name, (base_class,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
 
 
 def _generate_function_classes(scope_dict):
@@ -327,17 +327,17 @@
                     raise ValueError(class_name + " can only be differentiated once.")
                 return default_double_backwards_fn
             double_backwards_fn = make_default_double_backwards_fn(class_name)
-        primspec_fn = primspec_fns.get(class_name)
+        symbolic_fn = symbolic_fns.get(class_name)
         # This has to call a function to retain correct references to functions
         is_criterion_fn = 'Criterion' in fn
         if is_criterion_fn:
             cls, backward_cls = _make_function_class_criterion(class_name, update_output,
                                                                update_grad_input, acc_grad_parameters,
-                                                               double_backwards_fn, primspec_fn)
+                                                               double_backwards_fn, symbolic_fn)
         else:
             cls, backward_cls = _make_function_class(class_name, update_output,
                                                      update_grad_input, acc_grad_parameters,
-                                                     double_backwards_fn, primspec_fn)
+                                                     double_backwards_fn, symbolic_fn)
         scope_dict[class_name] = cls
         scope_dict[backward_cls.__name__] = backward_cls
         if not class_name.startswith('_'):
diff --git a/torch/nn/_functions/thnn/auto_primspec.py b/torch/nn/_functions/thnn/auto_symbolic.py
similarity index 61%
rename from torch/nn/_functions/thnn/auto_primspec.py
rename to torch/nn/_functions/thnn/auto_symbolic.py
index c77cb51..03cc188 100644
--- a/torch/nn/_functions/thnn/auto_primspec.py
+++ b/torch/nn/_functions/thnn/auto_symbolic.py
@@ -1,4 +1,4 @@
-def threshold_primspec(g, input, threshold=0, value=0, inplace=False):
+def threshold_symbolic(g, input, threshold=0, value=0, inplace=False):
     # TODO: [Export inplace]
     if threshold != 0:
         raise RuntimeError("Non-zero threshold in Threshold not supported")
@@ -7,12 +7,12 @@
     return g.op("Relu", input)
 
 
-def leakyrelu_primspec(g, input, negative_slope, inplace=False):
+def leakyrelu_symbolic(g, input, negative_slope, inplace=False):
     # TODO: [Export inplace]
     return g.op("LeakyRelu", input, alpha_f=negative_slope)
 
 
-primspec_fns = {
-    'Threshold': threshold_primspec,
-    'LeakyReLU': leakyrelu_primspec,
+symbolic_fns = {
+    'Threshold': threshold_symbolic,
+    'LeakyReLU': leakyrelu_symbolic,
 }
diff --git a/torch/nn/_functions/thnn/pooling.py b/torch/nn/_functions/thnn/pooling.py
index ee7ce08..55119e4 100644
--- a/torch/nn/_functions/thnn/pooling.py
+++ b/torch/nn/_functions/thnn/pooling.py
@@ -9,7 +9,7 @@
 class MaxPool1d(Function):
 
     @staticmethod
-    def primspec(g, input, kernel_size, stride=None, padding=0, dilation=1,
+    def symbolic(g, input, kernel_size, stride=None, padding=0, dilation=1,
                  ceil_mode=False):
         if ceil_mode:
             raise RuntimeError("ceil_mode not supported in MaxPool1d")
@@ -94,7 +94,7 @@
 class MaxPool2d(Function):
 
     @staticmethod
-    def primspec(g, input, kernel_size, stride=None, padding=0, dilation=1,
+    def symbolic(g, input, kernel_size, stride=None, padding=0, dilation=1,
                  ceil_mode=False):
         if ceil_mode:
             raise RuntimeError("ceil_mode not supported in MaxPool2d")
@@ -397,7 +397,7 @@
 class AvgPool2d(Function):
 
     @staticmethod
-    def primspec(g, input, kernel_size, stride=None, padding=0,
+    def symbolic(g, input, kernel_size, stride=None, padding=0,
                  ceil_mode=False, count_include_pad=True):
         if ceil_mode:
             raise RuntimeError("ceil_mode not supported in AvgPool2d")
diff --git a/torch/nn/_functions/thnn/sparse.py b/torch/nn/_functions/thnn/sparse.py
index 379b0ba..ecdcf93 100644
--- a/torch/nn/_functions/thnn/sparse.py
+++ b/torch/nn/_functions/thnn/sparse.py
@@ -9,7 +9,7 @@
 class Embedding(Function):
 
     @staticmethod
-    def primspec(g, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
+    def symbolic(g, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                  sparse=False):
         if max_norm is not None:
             raise ValueError('Right now, re-norm is not supported.')