codemod primspec -> symbol, PrimSpec -> Symbolic
diff --git a/torch/autograd/_functions/basic_ops.py b/torch/autograd/_functions/basic_ops.py
index 5217172..73d7e22 100644
--- a/torch/autograd/_functions/basic_ops.py
+++ b/torch/autograd/_functions/basic_ops.py
@@ -7,7 +7,7 @@
class Add(InplaceFunction):
@staticmethod
- def primspec(g, a, b, inplace=False):
+ def symbolic(g, a, b, inplace=False):
# TODO: [Export inplace]
return g.appendNode(g.create("Add", [a, b]))
@@ -29,7 +29,7 @@
class Sub(InplaceFunction):
@staticmethod
- def primspec(g, a, b, inplace=False):
+ def symbolic(g, a, b, inplace=False):
# TODO: [Export inplace]
return g.appendNode(g.create("Sub", [a, b]))
@@ -51,7 +51,7 @@
class Mul(Function):
@staticmethod
- def primspec(g, a, b, inplace=False):
+ def symbolic(g, a, b, inplace=False):
# TODO: [Export inplace]
return g.op("Mul", a, b)
@@ -231,7 +231,7 @@
class Negate(InplaceFunction):
@staticmethod
- def primspec(g, i, inplace=False):
+ def symbolic(g, i, inplace=False):
# TODO: [Export inplace]
return g.appendNode(g.create("Scale", [i]).f_("scale", -1))
diff --git a/torch/autograd/_functions/blas.py b/torch/autograd/_functions/blas.py
index 85c7acf..c7bcb3b2 100644
--- a/torch/autograd/_functions/blas.py
+++ b/torch/autograd/_functions/blas.py
@@ -16,7 +16,7 @@
class Addmm(InplaceFunction):
@staticmethod
- def primspec(g, add_matrix, matrix1, matrix2, alpha=1, beta=1, inplace=False):
+ def symbolic(g, add_matrix, matrix1, matrix2, alpha=1, beta=1, inplace=False):
# TODO: manually insert the necessary scaling, since ONNX doesn't
# natively support it
if alpha != 1 or beta != 1:
diff --git a/torch/autograd/_functions/pointwise.py b/torch/autograd/_functions/pointwise.py
index 726a8fb..a2a4276 100644
--- a/torch/autograd/_functions/pointwise.py
+++ b/torch/autograd/_functions/pointwise.py
@@ -55,7 +55,7 @@
class Tanh(InplaceFunction):
@staticmethod
- def primspec(g, i, inplace=False):
+ def symbolic(g, i, inplace=False):
# TODO: [Export inplace]
return g.op("Tanh", i)
@@ -85,7 +85,7 @@
class Sigmoid(InplaceFunction):
@staticmethod
- def primspec(g, i, inplace=False):
+ def symbolic(g, i, inplace=False):
return g.op("Sigmoid", i)
@staticmethod
diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py
index b6bed68..721b819 100644
--- a/torch/autograd/_functions/tensor.py
+++ b/torch/autograd/_functions/tensor.py
@@ -10,7 +10,7 @@
class Index(Function):
@staticmethod
- def primspec(g, i, index):
+ def symbolic(g, i, index):
# We should only expect index as an integer in this case.
# We use "Slice" to get the index-th element in i,
# Then we reduce the dimension using "Reshape".
@@ -95,7 +95,7 @@
class Transpose(Function):
@staticmethod
- def primspec(g, i, dim1, dim2):
+ def symbolic(g, i, dim1, dim2):
# NB: Swap dim1 and dim2, which is different from ONNX's
# Transpose, which is actually a permute.
if dim1 == dim2:
@@ -120,7 +120,7 @@
class View(Function):
@staticmethod
- def primspec(g, i, sizes):
+ def symbolic(g, i, sizes):
return g.op("Reshape", i, shape_i=sizes)
@staticmethod
@@ -197,7 +197,7 @@
class Permute(Function):
@staticmethod
- def primspec(g, input, dim_indices):
+ def symbolic(g, input, dim_indices):
if dim_indices == list(range(0, len(dim_indices))):
return input
return g.op("Transpose", input, perm_i=dim_indices)
@@ -351,7 +351,7 @@
class Concat(Function):
@staticmethod
- def primspec(g, dim, *inputs):
+ def symbolic(g, dim, *inputs):
n, _ = g.op("Concat", *inputs, axis_i=dim, outputs=2)
return n
@@ -409,7 +409,7 @@
class Squeeze(InplaceFunction):
@staticmethod
- def primspec(g, input, dim, inplace=False):
+ def symbolic(g, input, dim, inplace=False):
# TODO: [Export inplace]
if dim is None:
dims = []
@@ -594,7 +594,7 @@
class Chunk(Function):
@staticmethod
- def primspec(g, i, num_chunks, dim=0):
+ def symbolic(g, i, num_chunks, dim=0):
dim_size = i.type().sizes()[dim]
split_size = (dim_size + num_chunks - 1) // num_chunks
lengths = []
diff --git a/torch/csrc/autograd/functions/batch_normalization.h b/torch/csrc/autograd/functions/batch_normalization.h
index ad581f3..b565738 100644
--- a/torch/csrc/autograd/functions/batch_normalization.h
+++ b/torch/csrc/autograd/functions/batch_normalization.h
@@ -6,7 +6,7 @@
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/primspec.h"
+#include "torch/csrc/autograd/symbolic.h"
namespace torch { namespace autograd {
@@ -19,12 +19,12 @@
bool cudnn_enabled;
};
-struct BatchNormForward : public ForwardFunction<>, public BatchNormParams, public HasPrimSpec {
+struct BatchNormForward : public ForwardFunction<>, public BatchNormParams, public HasSymbolic {
BatchNormForward(BatchNormParams params)
: BatchNormParams(std::move(params)) {}
virtual variable_list apply(const variable_list& inputs) override;
- virtual jit::node_list primspec(PrimSpecContext* ctx, jit::node_list inputs) override;
+ virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
};
struct BatchNormBackward : public Function, public BatchNormParams {
diff --git a/torch/csrc/autograd/functions/convolution.h b/torch/csrc/autograd/functions/convolution.h
index 7317c0d..a328a62 100644
--- a/torch/csrc/autograd/functions/convolution.h
+++ b/torch/csrc/autograd/functions/convolution.h
@@ -8,7 +8,7 @@
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/autograd/primspec.h"
+#include "torch/csrc/autograd/symbolic.h"
#ifdef WITH_CUDNN
#include "torch/csrc/cudnn/Conv.h"
@@ -38,12 +38,12 @@
bool use_cudnn(const at::Tensor& input) const;
};
-struct ConvForward : public ForwardFunction<>, public ConvParams, public HasPrimSpec {
+struct ConvForward : public ForwardFunction<>, public ConvParams, public HasSymbolic {
explicit ConvForward(ConvParams params) : ConvParams(std::move(params)) {}
virtual std::string name() override;
virtual variable_list apply(const variable_list& inputs) override;
- virtual jit::node_list primspec(PrimSpecContext* ctx, jit::node_list inputs) override;
+ virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
std::vector<int64_t> output_size(at::Tensor& input, at::Tensor& weight);
};
diff --git a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp b/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
index c071a30..e51e858 100644
--- a/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
+++ b/torch/csrc/autograd/functions/onnx/batch_normalization.cpp
@@ -6,7 +6,7 @@
namespace torch {
namespace autograd {
-jit::node_list BatchNormForward::primspec(PrimSpecContext* ctx, jit::node_list inputs) {
+jit::node_list BatchNormForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) {
auto & g = ctx->graph;
// X, Scale, Bias
auto bn = g->appendNode(g->create(jit::kSpatialBN,{inputs.at(0),inputs.at(1),inputs.at(2)}));
diff --git a/torch/csrc/autograd/functions/onnx/convolution.cpp b/torch/csrc/autograd/functions/onnx/convolution.cpp
index 059869b..1b4f756 100644
--- a/torch/csrc/autograd/functions/onnx/convolution.cpp
+++ b/torch/csrc/autograd/functions/onnx/convolution.cpp
@@ -25,7 +25,7 @@
// passed here; it's done as an external addition. This is less efficient
// but this code should be temporary anyway.
-jit::node_list ConvForward::primspec(PrimSpecContext* ctx, jit::node_list inputs) {
+jit::node_list ConvForward::symbolic(SymbolicContext* ctx, jit::node_list inputs) {
auto & g = ctx->graph;
// See Note [Caffe2ConvTranspose]
auto n = g->create(!transposed ? jit::kConv : jit::kCaffe2ConvTranspose,
diff --git a/torch/csrc/autograd/primspec.h b/torch/csrc/autograd/symbolic.h
similarity index 68%
rename from torch/csrc/autograd/primspec.h
rename to torch/csrc/autograd/symbolic.h
index 381c462..9ce470e 100644
--- a/torch/csrc/autograd/primspec.h
+++ b/torch/csrc/autograd/symbolic.h
@@ -7,22 +7,22 @@
namespace torch { namespace autograd {
-struct PrimSpecContext {
+struct SymbolicContext {
jit::Graph* graph;
const std::unordered_map<void*, jit::Node*>* buffer_map;
int batch_norm_count = 0;
};
-struct primspec_unconvertible : public std::runtime_error {
+struct symbolic_unconvertible : public std::runtime_error {
using std::runtime_error::runtime_error;
};
-struct HasPrimSpec {
+struct HasSymbolic {
// Add some nodes to the ONNX protobuf, under the assumption that this node
// as a whole has the represented inputs and outputs. Raises a
- // primspec_unconvertible exception if conversion is not supported.
- virtual jit::node_list primspec(PrimSpecContext* ctx, jit::node_list inputs) = 0;
+ // symbolic_unconvertible exception if conversion is not supported.
+ virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) = 0;
};
}} // namespace torch::autograd
diff --git a/torch/csrc/onnx/export.cpp b/torch/csrc/onnx/export.cpp
index e63bd76..fb0978d 100644
--- a/torch/csrc/onnx/export.cpp
+++ b/torch/csrc/onnx/export.cpp
@@ -1,5 +1,5 @@
#include "torch/csrc/onnx/export.h"
-#include "torch/csrc/autograd/primspec.h"
+#include "torch/csrc/autograd/symbolic.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/Exceptions.h"
@@ -29,7 +29,7 @@
std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& g,
const std::unordered_map<void*, Node*>& old_buffer_map,
bool verbose) {
- torch::autograd::PrimSpecContext ctx;
+ torch::autograd::SymbolicContext ctx;
std::unordered_map<Node*, Node*> env;
std::shared_ptr<Graph> out_graph = std::make_shared<Graph>();
ctx.graph = out_graph.get();
@@ -51,24 +51,24 @@
ctx.buffer_map = &buffer_map;
// put the new outputs in our environment map, and
// copy the type from the input graph if they were not set by the
- // primspec
+ // symbolic
auto setOutputs = [&](Node * node, const node_list & outputs) {
auto old_outputs = node->outputs();
- // The primspec can produce less outputs than the actual IR node,
+ // The symbolic can produce less outputs than the actual IR node,
// because many IR nodes have an implicit extra trailing output
// of type Handle, which is irrelevant for the purposes of export.
- // It's bad design to ask the primspec() implementers to actually
+ // It's bad design to ask the symbolic() implementers to actually
// handle this!
- JIT_ASSERTM(outputs.size() <= old_outputs.size(), "primspec produced too many outputs");
+ JIT_ASSERTM(outputs.size() <= old_outputs.size(), "symbolic produced too many outputs");
size_t i = 0;
for(auto & old : old_outputs) {
// NB: There is at most one handle, and if it exists, it is the last input
if(i >= outputs.size()) {
- // primspecs do not deal with Handles at the moment, so we just
+ // symbolics do not deal with Handles at the moment, so we just
// assert the handle isn't actually used.
auto typ = old->typeOption();
JIT_ASSERTM(typ && typ->kind() == jit::TypeKind::HandleType,
- "primspec produced too few outputs");
+ "symbolic produced too few outputs");
env[old] = nullptr;
if (!old->uses().empty()) {
throw std::runtime_error("In ONNX export, handles should be unused");
@@ -97,25 +97,25 @@
// Selects are translated by multi-return nodes.
JIT_ASSERT(env.count(value) > 0);
IR_ELSEIFM(CppOp)
- if (auto fn = std::dynamic_pointer_cast<autograd::HasPrimSpec>(value->fn)) {
- auto outputs = fn->primspec(&ctx, fmap(node->inputs(), envFn));
+ if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
+ auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn));
setOutputs(node, outputs);
} else {
- throw std::runtime_error("CppOp doesn't define primspec " + value->name());
+ throw std::runtime_error("CppOp doesn't define symbolic " + value->name());
}
IR_ELSEIFM(PythonOp)
auto pyobj = py::handle(value->pyobj.get());
- if(!py::hasattr(pyobj, "primspec"))
- throw std::runtime_error("PythonOp doesn't define primspec " + value->name());
+ if(!py::hasattr(pyobj, "symbolic"))
+ throw std::runtime_error("PythonOp doesn't define symbolic " + value->name());
- py::object primspec_fn = pyobj.attr("primspec");
+ py::object symbolic_fn = pyobj.attr("symbolic");
- py::tuple py_primspec_args(1+value->cconv.size());
+ py::tuple py_symbolic_args(1+value->cconv.size());
auto node_it = node->inputs().begin();
auto scalar_it = value->scalar_args.begin();
Py_ssize_t input_nr = 0;
- py_primspec_args[input_nr++] = py::cast(ctx.graph);
+ py_symbolic_args[input_nr++] = py::cast(ctx.graph);
for (auto arg_type : value->cconv) {
py::object obj;
@@ -132,13 +132,13 @@
} else {
throw std::runtime_error("unexpected calling convention");
}
- py_primspec_args[input_nr++] = obj;
+ py_symbolic_args[input_nr++] = obj;
}
- py::object raw_output = py::reinterpret_steal<py::object>(PyObject_CallObject(primspec_fn.ptr(), py_primspec_args.ptr()));
+ py::object raw_output = py::reinterpret_steal<py::object>(PyObject_CallObject(symbolic_fn.ptr(), py_symbolic_args.ptr()));
if(!raw_output)
throw py::error_already_set();
if(raw_output.ptr() == Py_None)
- throw std::runtime_error("PythonOp's primspec returned None, indicating conversion not supported " + value->name());
+ throw std::runtime_error("PythonOp's symbolic returned None, indicating conversion not supported " + value->name());
node_list outputs;
if(py::isinstance<Node>(raw_output)) {
outputs.push_back(py::cast<Node*>(raw_output));
@@ -282,7 +282,7 @@
}
if (node->kind() == kUndefined && node->uses().empty()) {
// Undefined nodes never show up in ONNX; they're just a tool
- // to help primspecs do the right thing.
+ // to help symbolics do the right thing.
continue;
}
auto p_n = p_g->add_node();
diff --git a/torch/nn/_functions/dropout.py b/torch/nn/_functions/dropout.py
index ff2e4cf..3937cef 100644
--- a/torch/nn/_functions/dropout.py
+++ b/torch/nn/_functions/dropout.py
@@ -11,7 +11,7 @@
return input.new().resize_as_(input)
@staticmethod
- def primspec(g, input, p=0.5, train=False, inplace=False):
+ def symbolic(g, input, p=0.5, train=False, inplace=False):
if inplace:
return None
n = g.appendNode(g.create("Dropout", [input])
@@ -58,7 +58,7 @@
class FeatureDropout(Dropout):
@staticmethod
- def primspec(input, p=0.5, train=False, inplace=False):
+ def symbolic(input, p=0.5, train=False, inplace=False):
return None
@staticmethod
diff --git a/torch/nn/_functions/thnn/activation.py b/torch/nn/_functions/thnn/activation.py
index 8dc086b..1beb36d 100644
--- a/torch/nn/_functions/thnn/activation.py
+++ b/torch/nn/_functions/thnn/activation.py
@@ -10,7 +10,7 @@
class PReLU(Function):
@staticmethod
- def primspec(g, input, weight):
+ def symbolic(g, input, weight):
# TODO: Properly support numel in type()
if all(s == 1 for s in weight.type().sizes()):
raise RuntimeError("single weight shared among input channels not supported")
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index eb72a45..21fbd9d 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -7,13 +7,13 @@
from torch.autograd.function import Function, InplaceFunction, once_differentiable
from torch._thnn import type2backend
from .auto_double_backwards import double_backwards_fns
-from .auto_primspec import primspec_fns
+from .auto_symbolic import symbolic_fns
from . import _all_functions
def _make_function_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters,
- double_backwards_fn, primspec_fn):
+ double_backwards_fn, symbolic_fn):
weight_arg_idx = -1
for i, arg in enumerate(update_output.arguments):
if arg.name.startswith('weight'):
@@ -28,8 +28,8 @@
additional_arg_idx += 1
@staticmethod
- def primspec(*args, **kwargs):
- a = primspec_fn(*args, **kwargs)
+ def symbolic(*args, **kwargs):
+ a = symbolic_fn(*args, **kwargs)
return a
@staticmethod
@@ -78,7 +78,7 @@
backward_cls = type(class_name + "Backward", (Function,),
dict(forward=backward_cls_forward, backward=backward_cls_backward))
- return type(class_name, (Function,), dict(forward=forward, backward=backward, primspec=primspec)), backward_cls
+ return type(class_name, (Function,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
def _find_buffers(args, ignored_args):
@@ -94,7 +94,7 @@
def _make_function_class(class_name, update_output, update_grad_input, acc_grad_parameters,
- double_backwards_fn, primspec_fn):
+ double_backwards_fn, symbolic_fn):
def has_argument(fn, name):
for arg in fn.arguments:
if arg.name == name:
@@ -129,8 +129,8 @@
return tuple(additional_args)
@staticmethod
- def primspec(*args, **kwargs):
- return primspec_fn(*args, **kwargs)
+ def symbolic(*args, **kwargs):
+ return symbolic_fn(*args, **kwargs)
@staticmethod
def forward(ctx, input, *params):
@@ -251,7 +251,7 @@
backward_cls = type(class_name + "Backward", (base_class,), dict(forward=backward_cls_forward,
backward=backward_cls_backward))
- return type(class_name, (base_class,), dict(forward=forward, backward=backward, primspec=primspec)), backward_cls
+ return type(class_name, (base_class,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
def _generate_function_classes(scope_dict):
@@ -327,17 +327,17 @@
raise ValueError(class_name + " can only be differentiated once.")
return default_double_backwards_fn
double_backwards_fn = make_default_double_backwards_fn(class_name)
- primspec_fn = primspec_fns.get(class_name)
+ symbolic_fn = symbolic_fns.get(class_name)
# This has to call a function to retain correct references to functions
is_criterion_fn = 'Criterion' in fn
if is_criterion_fn:
cls, backward_cls = _make_function_class_criterion(class_name, update_output,
update_grad_input, acc_grad_parameters,
- double_backwards_fn, primspec_fn)
+ double_backwards_fn, symbolic_fn)
else:
cls, backward_cls = _make_function_class(class_name, update_output,
update_grad_input, acc_grad_parameters,
- double_backwards_fn, primspec_fn)
+ double_backwards_fn, symbolic_fn)
scope_dict[class_name] = cls
scope_dict[backward_cls.__name__] = backward_cls
if not class_name.startswith('_'):
diff --git a/torch/nn/_functions/thnn/auto_primspec.py b/torch/nn/_functions/thnn/auto_symbolic.py
similarity index 61%
rename from torch/nn/_functions/thnn/auto_primspec.py
rename to torch/nn/_functions/thnn/auto_symbolic.py
index c77cb51..03cc188 100644
--- a/torch/nn/_functions/thnn/auto_primspec.py
+++ b/torch/nn/_functions/thnn/auto_symbolic.py
@@ -1,4 +1,4 @@
-def threshold_primspec(g, input, threshold=0, value=0, inplace=False):
+def threshold_symbolic(g, input, threshold=0, value=0, inplace=False):
# TODO: [Export inplace]
if threshold != 0:
raise RuntimeError("Non-zero threshold in Threshold not supported")
@@ -7,12 +7,12 @@
return g.op("Relu", input)
-def leakyrelu_primspec(g, input, negative_slope, inplace=False):
+def leakyrelu_symbolic(g, input, negative_slope, inplace=False):
# TODO: [Export inplace]
return g.op("LeakyRelu", input, alpha_f=negative_slope)
-primspec_fns = {
- 'Threshold': threshold_primspec,
- 'LeakyReLU': leakyrelu_primspec,
+symbolic_fns = {
+ 'Threshold': threshold_symbolic,
+ 'LeakyReLU': leakyrelu_symbolic,
}
diff --git a/torch/nn/_functions/thnn/pooling.py b/torch/nn/_functions/thnn/pooling.py
index ee7ce08..55119e4 100644
--- a/torch/nn/_functions/thnn/pooling.py
+++ b/torch/nn/_functions/thnn/pooling.py
@@ -9,7 +9,7 @@
class MaxPool1d(Function):
@staticmethod
- def primspec(g, input, kernel_size, stride=None, padding=0, dilation=1,
+ def symbolic(g, input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False):
if ceil_mode:
raise RuntimeError("ceil_mode not supported in MaxPool1d")
@@ -94,7 +94,7 @@
class MaxPool2d(Function):
@staticmethod
- def primspec(g, input, kernel_size, stride=None, padding=0, dilation=1,
+ def symbolic(g, input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False):
if ceil_mode:
raise RuntimeError("ceil_mode not supported in MaxPool2d")
@@ -397,7 +397,7 @@
class AvgPool2d(Function):
@staticmethod
- def primspec(g, input, kernel_size, stride=None, padding=0,
+ def symbolic(g, input, kernel_size, stride=None, padding=0,
ceil_mode=False, count_include_pad=True):
if ceil_mode:
raise RuntimeError("ceil_mode not supported in AvgPool2d")
diff --git a/torch/nn/_functions/thnn/sparse.py b/torch/nn/_functions/thnn/sparse.py
index 379b0ba..ecdcf93 100644
--- a/torch/nn/_functions/thnn/sparse.py
+++ b/torch/nn/_functions/thnn/sparse.py
@@ -9,7 +9,7 @@
class Embedding(Function):
@staticmethod
- def primspec(g, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
+ def symbolic(g, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse=False):
if max_norm is not None:
raise ValueError('Right now, re-norm is not supported.')