| #pragma once |
| |
| #if defined(USE_GTEST) |
| #include <gtest/gtest.h> |
| #include <test/cpp/common/support.h> |
| #else |
| #include "c10/util/Exception.h" |
| #define ASSERT_EQ(x, y) AT_ASSERT((x) == (y)) |
| #define ASSERT_NE(x, y) AT_ASSERT((x) != (y)) |
| #define ASSERT_TRUE AT_ASSERT |
| #define ASSERT_FALSE(x) ASSERT_TRUE(!(x)) |
#define ASSERT_THROWS_WITH(statement, substring)                           \
  {                                                                        \
    bool threw = false;                                                    \
    try {                                                                  \
      (void)statement;                                                     \
    } catch (const std::exception& e) {                                    \
      threw = true;                                                        \
      ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
    }                                                                      \
    ASSERT_TRUE(threw);                                                    \
  }
#define ASSERT_ANY_THROW(statement)   \
  {                                   \
    bool threw = false;               \
    try {                             \
      (void)statement;                \
    } catch (const std::exception&) { \
      threw = true;                   \
    }                                 \
    ASSERT_TRUE(threw);               \
  }

| #endif // defined(USE_GTEST) |
| |
| #include "torch/csrc/autograd/variable.h" |
| #include "torch/csrc/jit/argument_spec.h" |
| #include "torch/csrc/jit/assertions.h" |
| #include "torch/csrc/jit/attributes.h" |
| #include "torch/csrc/jit/autodiff.h" |
| #include "torch/csrc/jit/code_template.h" |
| #include "torch/csrc/jit/custom_operator.h" |
| #include "torch/csrc/jit/dynamic_dag.h" |
| #include "torch/csrc/jit/fuser/interface.h" |
| #include "torch/csrc/jit/interned_strings.h" |
| #include "torch/csrc/jit/interpreter.h" |
| #include "torch/csrc/jit/ir.h" |
| #include "torch/csrc/jit/operator.h" |
| #include "torch/csrc/jit/passes/alias_analysis.h" |
| #include "torch/csrc/jit/passes/common_subexpression_elimination.h" |
| #include "torch/csrc/jit/passes/constant_propagation.h" |
| #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h" |
| #include "torch/csrc/jit/passes/dead_code_elimination.h" |
| #include "torch/csrc/jit/passes/lower_grad_of.h" |
| #include "torch/csrc/jit/passes/lower_tuples.h" |
| #include "torch/csrc/jit/passes/requires_grad_analysis.h" |
| #include "torch/csrc/jit/passes/shape_analysis.h" |
| #include "torch/csrc/jit/passes/utils/subgraph_utils.h" |
| #include "torch/csrc/jit/symbolic_variable.h" |
| #include "torch/csrc/jit/tracer.h" |
| #include "torch/csrc/utils/hash.h" |
| #include "torch/csrc/variable_tensor_functions.h" |
| #include "torch/csrc/autograd/generated/variable_factories.h" |
| |
| #include "torch/csrc/autograd/engine.h" |
| #include "torch/csrc/autograd/variable.h" |
| |
| #include "torch/csrc/jit/graph_executor.h" |
| #include "torch/csrc/jit/ivalue.h" |
| #include "torch/csrc/jit/script/compiler.h" |
| #include "torch/csrc/jit/script/module.h" |
| |
| #include "onnx/onnx_pb.h" |
| |
| #include <ATen/ATen.h> |
| |
| #include <c10/util/Exception.h> |
| |
| #include <algorithm> |
| #include <cstddef> |
| #include <functional> |
| #include <iostream> |
| #include <memory> |
| #include <stdexcept> |
| #include <string> |
| #include <tuple> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| |
| namespace torch { |
| namespace jit { |
| namespace { |
| |
| using Var = SymbolicVariable; |
| |
| using namespace torch::autograd; |
| |
| template <typename T> |
| std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) { |
| size_t i = 0; |
| out << "{"; |
| for (auto&& e : list) { |
| if (i++ > 0) |
| out << ", "; |
| out << e; |
| } |
| out << "}"; |
| return out; |
| } |
| auto ct = CodeTemplate(R"( |
| int foo($args) { |
| |
| $bar |
| $bar |
| $a+$b |
| } |
| int commatest(int a${,stuff}) |
| int notest(int a${,empty,}) |
| )"); |
| auto ct_expect = R"( |
| int foo(hi, 8) { |
| |
| what |
| on many |
| lines... |
| 7 |
| what |
| on many |
| lines... |
| 7 |
| 3+4 |
| } |
| int commatest(int a, things..., others) |
| int notest(int a) |
| )"; |
| |
| void testCodeTemplate() { |
| { |
| TemplateEnv e; |
| e.s("hi", "foo"); |
| e.v("what", {"is", "this"}); |
| TemplateEnv c(e); |
| c.s("hi", "foo2"); |
| ASSERT_EQ(e.s("hi"), "foo"); |
| ASSERT_EQ(c.s("hi"), "foo2"); |
| ASSERT_EQ(e.v("what")[0], "is"); |
| } |
| |
| { |
| TemplateEnv e; |
| e.v("args", {"hi", "8"}); |
| e.v("bar", {"what\non many\nlines...", "7"}); |
| e.s("a", "3"); |
| e.s("b", "4"); |
| e.v("stuff", {"things...", "others"}); |
| e.v("empty", {}); |
| auto s = ct.format(e); |
| // std::cout << "'" << s << "'\n"; |
| // std::cout << "'" << ct_expect << "'\n"; |
| ASSERT_EQ(s, ct_expect); |
| } |
| } |
| |
| Value* appendNewNode(NodeKind kind, Graph& graph, ArrayRef<Value*> inputs) { |
| return graph.appendNode(graph.create(kind, inputs))->output(); |
| } |
| |
| void testFusion() { |
| auto testSimple = [&] { |
| Graph graph; |
| Var i0 = Var::asNewInput(graph); |
| Var i1 = Var::asNewInput(graph); |
| auto o0 = i0 * i1; |
| o0.addAsOutput(); |
| auto a = at::rand({3, 4}, at::kCUDA); |
| auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1); |
| auto outputs = debugLaunchGraph(graph, {a, b}); |
| ASSERT_EQ(outputs.size(), 1); |
| auto o2 = a * b; |
| float max_diff = (o2 - outputs[0]).abs().max().item<double>(); |
| // std::cout << "max diff: " << max_diff << "\n"; |
| ASSERT_EQ(max_diff, 0); |
| }; |
| testSimple(); |
| |
  auto testOne = [&](int ti, int tj) {
| Graph graph; |
| |
| Var i0 = Var::asNewInput(graph); |
| Var i1 = Var::asNewInput(graph); |
| Var i2 = Var::asNewInput(graph); |
| Var i3 = Var::asNewInput(graph); |
| Var i4 = Var::asNewInput(graph); |
| |
| auto p22 = i4.sigmoid(); |
| auto p20 = i3.sigmoid(); |
| auto p18 = i2.tanh(); |
| auto p16 = i1.sigmoid(); |
| auto p14 = p20 * i0; |
| auto p11 = p22 * p18; |
| auto o1 = p14 + p11; |
| auto p5 = o1.tanh(); |
| auto o0 = p16 * p5; |
| o0.addAsOutput(); |
| o1.addAsOutput(); |
| |
| graph.lint(); |
| |
| std::vector<at::Tensor> inputs; |
| // We want to generate input/output tensors with dimension 128x128x32, but |
| // with different internal strides. To do this, we generate a tensor |
| // with the "wrong" dimensions, and then use transpose to get an |
| // appropriately sized view. |
| for (size_t i = 0; i < graph.inputs().size(); i++) { |
| std::vector<int64_t> dims = {128, 128, 32}; |
| std::swap(dims[ti], dims[tj]); |
| inputs.push_back(at::rand(dims, at::kCUDA).transpose(ti, tj)); |
| } |
| |
| auto t22 = inputs[4].sigmoid(); |
| auto t20 = inputs[3].sigmoid(); |
| auto t18 = inputs[2].tanh(); |
| auto t16 = inputs[1].sigmoid(); |
| auto t14 = t20 * inputs[0]; |
| auto t11 = t22 * t18; |
| auto out1 = t14 + t11; |
| auto t5 = out1.tanh(); |
| auto out0 = t16 * t5; |
| |
| auto outputs = debugLaunchGraph(graph, inputs); |
| ASSERT_EQ(outputs.size(), graph.outputs().size()); |
| ASSERT_TRUE(out0.is_same_size(outputs.front())); |
| float max_diff = (outputs.front() - out0).abs().max().item<double>(); |
| ASSERT_TRUE(max_diff < 1e-6); |
| }; |
  testOne(0, 0);
  testOne(0, 1);
  testOne(1, 2);
  testOne(0, 2);
| |
| auto createFusedConcat = |
| [](Graph& graph, at::ArrayRef<Value*> inputs, int64_t dim) -> Value* { |
| return graph |
| .insertNode(graph.create(prim::FusedConcat, inputs)->i_(attr::dim, dim)) |
| ->output(); |
| }; |
| |
| auto testConcat = [&](int dim) { |
| Graph graph; |
| Var i0 = Var::asNewInput(graph); |
| Var i1 = Var::asNewInput(graph); |
| auto o0 = i0 * i1; |
| o0.addAsOutput(); |
| Var(createFusedConcat(graph, {i0, o0}, dim)).addAsOutput(); |
| |
| auto a = at::rand({3, 4, 5}, at::kCUDA); |
| auto b = at::rand({4, 3, 5}, at::kCUDA).transpose(0, 1); |
| |
| auto o_r = a * b; |
| auto o2_r = at::cat({a, o_r}, dim); |
| auto outputs = debugLaunchGraph(graph, {a, b}); |
| ASSERT_EQ(outputs.size(), 2); |
| |
| float max_diff = (o_r - outputs[0]).abs().max().item<double>(); |
| ASSERT_EQ(max_diff, 0); |
| float max_diff2 = (o2_r - outputs[1]).abs().max().item<double>(); |
| ASSERT_EQ(max_diff2, 0); |
| }; |
| testConcat(0); |
| testConcat(1); |
| testConcat(2); |
| } |
| |
| struct Attr : public Attributes<Attr> {}; |
| void testAttributes() { |
| auto one = attr::alpha; |
| auto two = attr::device; |
| auto three = attr::end; |
| auto four = attr::perm; |
| Attr attr; |
| attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what"); |
| ASSERT_EQ(attr.f(one), 3.4); |
| ASSERT_EQ(attr.s(three), "what"); |
| ASSERT_EQ(attr.i(two), 5); |
| attr.s_(one, "no"); |
| ASSERT_EQ(attr.s(one), "no"); |
| ASSERT_TRUE(attr.hasAttribute(three)); |
| ASSERT_TRUE(!attr.hasAttribute(four)); |
| attr.ss_(two, {"hi", "now"}); |
| ASSERT_EQ(attr.ss(two).at(1), "now"); |
| |
| Attr attr2; |
| attr2.copyAttributes(attr); |
| ASSERT_EQ(attr2.s(one), "no"); |
| attr2.f_(one, 5); |
| ASSERT_EQ(attr.s(one), "no"); |
| ASSERT_EQ(attr2.f(one), 5); |
| } |
| |
| void testInternedStrings() { |
| ASSERT_EQ(prim::Param, Symbol::prim("Param")); |
| ASSERT_EQ(prim::Return, Symbol::prim("Return")); |
| ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return")); |
| ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return")); |
| Symbol newsym = Symbol::aten("__NEW_SYMBOL"); |
| size_t symstart = newsym; |
| ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL")); |
| // TODO: This test is a bit too close to the implementation details. |
| ASSERT_EQ(Symbol::aten("What"), symstart + 1); |
| ASSERT_EQ(Symbol::aten("What2"), symstart + 2); |
| ASSERT_EQ(Symbol::aten("What"), symstart + 1); |
| ASSERT_EQ(Symbol::aten("What2"), symstart + 2); |
| ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2")); |
| } |
| |
| void testFromQualString() { |
| ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param")); |
| ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm")); |
| ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM")); |
| ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value")); |
| ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope("")); |
| ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string("")); |
| ASSERT_EQ( |
| Symbol::fromQualString("::").ns().toQualString(), |
| std::string("namespaces::")); |
| ASSERT_EQ( |
| Symbol::fromQualString("new_ns::param").toUnqualString(), |
| std::string("param")); |
| ASSERT_EQ( |
| Symbol::fromQualString("new_ns::param").ns().toUnqualString(), |
| std::string("new_ns")); |
| ASSERT_EQ( |
| Symbol::fromQualString("new_ns::param").ns(), |
| Symbol::fromQualString("namespaces::new_ns")); |
| |
| auto bad_inputs = {"scope", ":", ""}; |
  for (auto input : bad_inputs) {
    ASSERT_ANY_THROW(Symbol::fromQualString(input));
  }
| } |
| |
| at::Tensor t_use(at::Tensor x) { |
| return x; |
| } |
| at::Tensor t_def(at::Tensor x) { |
| return x.t(); |
| } |
| |
// Given the difference between the output and expected tensors, check whether
// the difference is within a relative tolerance range. This is a standard way
// of matching tensor values up to a certain precision.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor>& inputs) {
| double maxValue = 0.0; |
| for (auto& tensor : inputs) { |
| maxValue = fmax(tensor.abs().max().item<float>(), maxValue); |
| } |
| return diff.abs().max().item<float>() < 2e-6 * maxValue; |
| } |
| bool almostEqual(const at::Tensor& a, const at::Tensor& b) { |
| return checkRtol(a - b, {a, b}); |
| } |
| |
| bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { |
| return (a - b).abs().max().item<float>() == 0.f; |
| } |
| |
| std::pair<at::Tensor, at::Tensor> lstm( |
| at::Tensor input, |
| at::Tensor hx, |
| at::Tensor cx, |
| at::Tensor w_ih, |
| at::Tensor w_hh) { |
| auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh)); |
| |
| auto chunked_gates = gates.chunk(4, 1); |
| auto ingate = chunked_gates[0]; |
| auto forgetgate = chunked_gates[1]; |
| auto cellgate = chunked_gates[2]; |
| auto outgate = chunked_gates[3]; |
| |
| ingate = ingate.sigmoid(); |
| outgate = outgate.sigmoid(); |
| cellgate = cellgate.tanh(); |
| forgetgate = forgetgate.sigmoid(); |
| |
| auto cy = (forgetgate * cx) + (ingate * cellgate); |
| auto hy = outgate * cy.tanh(); |
| |
| return {hy, cy}; |
| } |
| |
| std::tuple<Var, Var> build_lstm_body( |
| Graph& g, |
| Var input, |
| Var hx, |
| Var cx, |
| Var w_ih, |
| Var w_hh) { |
| auto gates = input.mm(w_ih); |
| gates = gates + hx.mm(w_hh); |
| auto outputs = gates.chunk(4, 1); |
| auto ingate = outputs[0]; |
| auto forgetgate = outputs[1]; |
| auto cellgate = outputs[2]; |
| auto outgate = outputs[3]; |
| ingate = ingate.sigmoid(); |
| outgate = outgate.sigmoid(); |
| cellgate = cellgate.tanh(); |
| forgetgate = forgetgate.sigmoid(); |
| |
| auto cy = forgetgate * cx; |
| cy = cy + ingate * cellgate; |
| auto hy = outgate * cy.tanh(); |
| |
| return std::make_tuple(hy, cy); |
| } |
| |
| std::shared_ptr<Graph> build_lstm() { |
| auto r = std::make_shared<Graph>(); |
| auto& g = *r; |
| Value* input = g.addInput(); |
| Value* hx = g.addInput(); |
| Value* cx = g.addInput(); |
| Value* w_ih = g.addInput(); |
| Value* w_hh = g.addInput(); |
| |
| Var hy; |
| Var cy; |
| std::tie(hy, cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh); |
| |
| hy.addAsOutput(); |
| cy.addAsOutput(); |
| g.lint(); |
| |
| return r; |
| } |
| |
void run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs) {
| std::vector<IValue> stack(inputs.begin(), inputs.end()); |
| interp.run(stack); |
| outputs.clear(); |
| for (auto& ivalue : stack) { |
| outputs.push_back(std::move(ivalue).toTensor()); |
| } |
| } |
| |
| std::pair<tensor_list, tensor_list> runGradient( |
| Gradient& grad_spec, |
| tensor_list& tensors_in, |
| tensor_list& tensor_grads_in) { |
| tensor_list tensors_out, tensor_grads_out; |
| Code f_code{grad_spec.f}, df_code{grad_spec.df}; |
| InterpreterState f_interpreter{f_code}, df_interpreter{df_code}; |
| |
| run(f_interpreter, tensors_in, tensors_out); |
| |
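  // df expects its inputs in this order: the incoming gradients for f's
  // outputs, then the captured inputs, then the captured outputs of f (as
  // the two capture loops below encode).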
| tensor_list df_inputs; |
| df_inputs.insert( |
| df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end()); |
| for (auto offset : grad_spec.df_input_captured_inputs) |
| df_inputs.push_back(tensors_in[offset]); |
| for (auto offset : grad_spec.df_input_captured_outputs) |
| df_inputs.push_back(tensors_out[offset]); |
| run(df_interpreter, df_inputs, tensor_grads_out); |
| |
  // Outputs of f beyond f_real_outputs are only captured state for df,
  // so slice them off before returning.
| tensors_out.erase( |
| tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end()); |
| return std::make_pair(tensors_out, tensor_grads_out); |
| } |
| |
| void assertAllClose(const tensor_list& a, const tensor_list& b) { |
| ASSERT_EQ(a.size(), b.size()); |
| for (size_t i = 0; i < a.size(); ++i) { |
| ASSERT_TRUE(a[i].is_same_size(b[i])); |
| ASSERT_TRUE(a[i].allclose(b[i])); |
| } |
| } |
| |
| void testInterp() { |
| constexpr int batch_size = 4; |
| constexpr int input_size = 256; |
| constexpr int seq_len = 32; |
| |
| int hidden_size = 2 * input_size; |
| |
| auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA); |
| auto hx = at::randn({batch_size, hidden_size}, at::kCUDA); |
| auto cx = at::randn({batch_size, hidden_size}, at::kCUDA); |
| auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA)); |
| auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA)); |
| |
| auto lstm_g = build_lstm(); |
| Code lstm_function(lstm_g); |
| std::vector<at::Tensor> outputs; |
| InterpreterState lstm_interp(lstm_function); |
| run(lstm_interp, {input[0], hx, cx, w_ih, w_hh}, outputs); |
| std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh); |
| |
| // std::cout << almostEqual(outputs[0],hx) << "\n"; |
| ASSERT_TRUE(exactlyEqual(outputs[0], hx)); |
| ASSERT_TRUE(exactlyEqual(outputs[1], cx)); |
| } |
| |
| void testTHNNConv() { |
| std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W |
| std::vector<int64_t> kernel_size = {3, 5}; |
| std::vector<int64_t> stride = {1, 2}; |
| std::vector<int64_t> padding = {2, 1}; |
| constexpr int out_channels = 5; |
| |
| // make inputs |
| at::Tensor input = torch::randn(input_size); |
| at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]}); |
| at::Tensor bias = torch::randn({out_channels}); |
| |
| // run forward eagerly |
| at::Tensor output, finput, fgradinput; |
| std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(input, weight, kernel_size, |
| bias, stride, padding); |
| |
| // make grad_outputs |
| at::Tensor grad_output = torch::randn_like(output); |
| at::Tensor grad_finput = torch::zeros_like(finput); |
| at::Tensor grad_fgradinput = torch::zeros_like(fgradinput); |
| |
| // run backward eagerly |
| at::Tensor grad_input, grad_weight, grad_bias; |
| std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight, |
| kernel_size, stride, padding, |
| finput, fgradinput, {true, true, true}); |
| |
| // make JIT graph |
| auto graph = std::make_shared<Graph>(); |
| auto ksz_val = graph->insertConstant(IValue(kernel_size)); |
| auto kst_val = graph->insertConstant(IValue(stride)); |
| auto pad_val = graph->insertConstant(IValue(padding)); |
| |
| auto inputg = graph->addInput("self"); |
| auto weightg = graph->addInput("weight"); |
| auto biasg = graph->addInput("bias"); |
| |
| Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val}); |
| auto outputs = conv->node()->outputs(); |
| for (auto output : outputs) { |
| graph->registerOutput(output); |
| } |
| LowerAllTuples(graph); |
| graph->lint(); |
| |
| // differentiate JIT graph |
| EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick |
| ConstantPropagation(graph); |
| auto grad_spec = differentiate(graph); |
| LowerGradOf(*grad_spec.df); |
| |
| // prepare JIT inputs / gradients |
| tensor_list tensors_in; |
| tensors_in.push_back(input); |
| tensors_in.push_back(weight); |
| tensors_in.push_back(bias); |
| |
| tensor_list tensor_grads_in; |
| tensor_grads_in.push_back(grad_output); |
| tensor_grads_in.push_back(grad_finput); |
| tensor_grads_in.push_back(grad_fgradinput); |
| |
| // Get outputs from the interpreter |
| tensor_list tensors_out, tensor_grads_out; |
| std::tie(tensors_out, tensor_grads_out) = |
| runGradient(grad_spec, tensors_in, tensor_grads_in); |
| |
| // prepare expected structs |
| tensor_list expected_tensors_out, expected_tensor_grads_out; |
| expected_tensors_out.push_back(output); |
| expected_tensors_out.push_back(finput); |
| expected_tensors_out.push_back(fgradinput); |
| expected_tensor_grads_out.push_back(grad_input); |
| expected_tensor_grads_out.push_back(grad_weight); |
| expected_tensor_grads_out.push_back(grad_bias); |
| |
| // Compare results |
| assertAllClose(tensors_out, expected_tensors_out); |
| assertAllClose(tensor_grads_out, expected_tensor_grads_out); |
| } |
| |
| void testATenNativeBatchNorm() { |
| // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) |
| std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W |
| bool training = true; |
| float momentum = 0.9; |
| float eps = 1e-5; |
| |
| // make inputs |
| at::Tensor input = torch::randn(input_size); |
| at::Tensor weight = torch::randn({input_size[1]}); |
| at::Tensor bias = torch::randn({input_size[1]}); |
| at::Tensor running_mean = torch::randn({input_size[1]}); |
| at::Tensor running_var = torch::randn({input_size[1]}); |
| |
  // running_mean and running_var are updated in-place, so clone them and
  // give the eager and JIT runs their own copies
| at::Tensor running_mean_eager = running_mean.clone(); |
| at::Tensor running_var_eager = running_var.clone(); |
| at::Tensor running_mean_jit = running_mean.clone(); |
| at::Tensor running_var_jit = running_var.clone(); |
| |
| // run forward eagerly |
| at::Tensor output, savemean, saveinvstd; |
| std::tie(output, savemean, saveinvstd) = at::native_batch_norm(input, weight, bias, running_mean_eager, running_var_eager, training, momentum, eps); |
| |
| // make grad_outputs |
| at::Tensor grad_output = torch::randn_like(output); |
| at::Tensor grad_savemean = torch::zeros_like(savemean); |
| at::Tensor grad_saveinvstd = torch::zeros_like(saveinvstd); |
| |
| // run backward eagerly |
| at::Tensor grad_input, grad_weight, grad_bias; |
| // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) |
| std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(grad_output, input, weight, |
| running_mean_eager, running_var_eager, |
| savemean, saveinvstd, training, eps, {true, true, true}); |
| |
| // make JIT graph |
| auto graph = std::make_shared<Graph>(); |
| auto training_val = graph->insertConstant(IValue(training)); |
| auto momentum_val = graph->insertConstant(IValue(momentum)); |
| auto eps_val = graph->insertConstant(IValue(eps)); |
| |
| auto inputg = graph->addInput("self"); |
| auto weightg = graph->addInput("weight"); |
| auto biasg = graph->addInput("bias"); |
| auto running_meang = graph->addInput("running_mean"); |
| auto running_varg = graph->addInput("running_var"); |
| |
| Value* bn = graph->insert(aten::native_batch_norm, {inputg, weightg, biasg, running_meang, running_varg, training_val, momentum_val, eps_val}); |
| auto outputs = bn->node()->outputs(); |
| for (auto output : outputs) { |
| graph->registerOutput(output); |
| } |
| LowerAllTuples(graph); |
| graph->lint(); |
| |
| // differentiate JIT graph |
| EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick |
| ConstantPropagation(graph); |
| auto grad_spec = differentiate(graph); |
| LowerGradOf(*grad_spec.df); |
| |
| // prepare JIT inputs / gradients |
| tensor_list tensors_in; |
| tensors_in.push_back(input); |
| tensors_in.push_back(weight); |
| tensors_in.push_back(bias); |
| tensors_in.push_back(running_mean_jit); |
| tensors_in.push_back(running_var_jit); |
| |
| tensor_list tensor_grads_in; |
| tensor_grads_in.push_back(grad_output); |
| tensor_grads_in.push_back(grad_savemean); |
| tensor_grads_in.push_back(grad_saveinvstd); |
| |
| // Get outputs from the interpreter |
| tensor_list tensors_out, tensor_grads_out; |
| std::tie(tensors_out, tensor_grads_out) = |
| runGradient(grad_spec, tensors_in, tensor_grads_in); |
| |
| // prepare expected structs |
| tensor_list expected_tensors_out, expected_tensor_grads_out; |
| expected_tensors_out.push_back(output); |
| expected_tensors_out.push_back(savemean); |
| expected_tensors_out.push_back(saveinvstd); |
| expected_tensors_out.push_back(running_mean_eager); |
| expected_tensors_out.push_back(running_var_eager); |
| expected_tensor_grads_out.push_back(grad_input); |
| expected_tensor_grads_out.push_back(grad_weight); |
| expected_tensor_grads_out.push_back(grad_bias); |
| |
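  // The JIT run updated running_mean_jit / running_var_jit in-place; append
  // them so they are compared against the eager-mode running stats captured
  // in expected_tensors_out above.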
| tensors_out.push_back(running_mean_jit); |
| tensors_out.push_back(running_var_jit); |
| |
| // Compare results |
| assertAllClose(tensors_out, expected_tensors_out); |
| assertAllClose(tensor_grads_out, expected_tensor_grads_out); |
| } |
| |
| using var_meta_type = std::vector<int64_t>; |
| using var_meta_list = std::vector<var_meta_type>; |
| using test_fn_type = std::function<variable_list(const variable_list&)>; |
| |
| struct ADTestSpec { |
| ADTestSpec(const char* name, var_meta_list input_meta, test_fn_type test_fn) |
| : name(name), input_meta(input_meta), test_fn(test_fn) {} |
| |
  variable_list operator()(const variable_list& inputs) const {
    return test_fn(inputs);
  }
| |
| std::vector<Variable> make_vars() const { |
| std::vector<Variable> out; |
| for (const auto& m : input_meta) { |
| out.push_back(torch::randn(m, at::requires_grad(true))); |
| } |
| return out; |
| } |
| |
| const char* name; |
| var_meta_list input_meta; |
| test_fn_type test_fn; |
| }; |
| |
| variable_list get_grad_outputs(const variable_list& vars) { |
| return fmap(vars, [](const Variable& v) -> Variable { |
| return at::randn(v.sizes(), v.options()); |
| }); |
| } |
| |
| std::shared_ptr<Graph> trace( |
| const ADTestSpec& test, |
| const variable_list& vars_in) { |
| std::shared_ptr<tracer::TracingState> state; |
| Stack trace_stack_in; |
| std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in)); |
| variable_list trace_vars_in = fmap( |
| trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); }); |
| auto trace_vars_out = test(trace_vars_in); |
| tracer::exit(fmap<IValue>(trace_vars_out)); |
| return state->graph; |
| } |
| |
| variable_list grad( |
| const variable_list& outputs, |
| const variable_list& inputs, |
| const variable_list& grad_outputs) { |
| const auto get_edge = [](const Variable& v) { return v.gradient_edge(); }; |
| auto& engine = torch::autograd::Engine::get_default_engine(); |
  return engine.execute(
      fmap(outputs, get_edge),
      grad_outputs,
      /*keep_graph=*/true,
      /*create_graph=*/false,
      fmap(inputs, get_edge));
| } |
| |
| void testADFormulas() { |
| const auto unwrap = [](const Variable& v) { return v.data(); }; |
| |
| using VL = variable_list; |
| const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}}; |
| const var_meta_list unary_pointwise = {{2, 3, 4, 5}}; |
| const var_meta_list unary_pointwise_2d = {{2, 3}}; |
| const std::vector<ADTestSpec> ad_tests = { |
| {"add", |
| binary_pointwise, |
| [](const VL& v) -> VL { return {v[0] + v[1]}; }}, |
| {"sub", |
| binary_pointwise, |
| [](const VL& v) -> VL { return {v[0] - v[1]}; }}, |
| {"mul", |
| binary_pointwise, |
| [](const VL& v) -> VL { return {v[0] * v[1]}; }}, |
| {"sigmoid", |
| unary_pointwise, |
| [](const VL& v) -> VL { return {v[0].sigmoid()}; }}, |
| {"tanh", |
| unary_pointwise, |
| [](const VL& v) -> VL { return {v[0].tanh()}; }}, |
| {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }}, |
| {"view", |
| unary_pointwise_2d, |
| [](const VL& v) -> VL { return {v[0].view({3, 2})}; }}, |
| {"expand", |
| {{2, 1}}, |
| [](const VL& v) -> VL { return {v[0].expand({2, 3})}; }}, |
| {"mm", |
| {{10, 12}, {12, 15}}, |
| [](const VL& v) -> VL { return {v[0].mm(v[1])}; }}, |
| // TODO: enable once we'll be able to capture lists across forward-backward |
| //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return |
| // fmap<Variable>(v[0].chunk(4, 1)); }}, |
| //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return |
| // fmap<Variable>(v[0].chunk(3, 2)); }}, |
| //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return |
| // fmap<Variable>(v[0].split(4, 1)); }}, |
| //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return |
| // fmap<Variable>(v[0].split(3, 2)); }}, |
| }; |
| |
| for (const auto& test : ad_tests) { |
    // Get reference values from autograd
| auto vars_in = test.make_vars(); |
| auto vars_out = test(vars_in); |
| auto var_grads_in = get_grad_outputs(vars_out); |
| auto var_grads_out = grad(vars_out, vars_in, var_grads_in); |
| |
| // Trace and differentiate the op |
| auto graph = trace(test, vars_in); |
| EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick |
| ConstantPropagation(graph); |
| auto grad_spec = differentiate(graph); |
| LowerGradOf(*grad_spec.df); |
| // Get outputs from the interpreter |
| auto tensors_in = fmap(vars_in, unwrap); |
| auto tensor_grads_in = fmap(var_grads_in, unwrap); |
| tensor_list tensors_out, tensor_grads_out; |
| std::tie(tensors_out, tensor_grads_out) = |
| runGradient(grad_spec, tensors_in, tensor_grads_in); |
| |
| // Compare results |
| auto expected_tensors_out = fmap(vars_out, unwrap); |
| auto expected_tensor_grads_out = fmap(var_grads_out, unwrap); |
| assertAllClose(tensors_out, expected_tensors_out); |
| assertAllClose(tensor_grads_out, expected_tensor_grads_out); |
| } |
| } |
| |
std::string toString(const std::shared_ptr<Graph>& graph) {
| std::ostringstream s; |
| s << *graph; |
| return s.str(); |
| } |
| |
| void testDifferentiate(std::ostream& out = std::cout) { |
| auto graph = std::make_shared<Graph>(); |
| at::ScalarType s = at::ScalarType::Float; |
| auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1}); |
| |
| // Build up a fake graph |
| auto a = SymbolicVariable::asNewInput(*graph, type); |
| auto b = SymbolicVariable::asNewInput(*graph, type); |
| auto c = a * b * a + b; |
| graph->registerOutput(c.value()); |
| |
| auto grad_spec = differentiate(graph); |
| std::vector<size_t> expected_captured_inputs = {0, 1}; |
| std::vector<size_t> expected_captured_outputs = {1}; |
| std::vector<size_t> expected_input_vjps = {0, 1}; |
| std::vector<size_t> expected_output_vjps = {0, 1}; |
| ASSERT_EQ(grad_spec.f_real_outputs, 1); |
| ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs); |
| ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs); |
| ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps); |
| ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps); |
| out << "testDifferentiate\n"; |
| out << *grad_spec.f; |
| out << *grad_spec.df; |
| out << "\n"; |
| } |
| |
| void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) { |
| // Build up a fake graph |
| auto graph = std::make_shared<Graph>(); |
| auto a = SymbolicVariable::asNewInput(*graph); |
| auto b = SymbolicVariable::asNewInput(*graph); |
| auto d = b * b + b; |
| auto e = (d + a) * a + b; |
| graph->registerOutput(d.value()); |
| graph->registerOutput(e.value()); |
| |
| auto a_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true); |
| auto b_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false); |
| setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2)); |
| PropagateInputShapes(graph); |
| PropagateRequiresGrad(graph); |
| |
| auto grad_spec = differentiate(graph); |
| std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a) |
| std::vector<size_t> expected_output_vjps = {0}; // only a requires grad |
| ASSERT_EQ(grad_spec.f_real_outputs, 2); // we need one temporary %4 = (d + a) |
| ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0})); |
| ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2})); |
| ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps); |
| ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps); |
| out << "testDifferentiateWithRequiresGrad\n"; |
| out << *grad_spec.f; |
| out << *grad_spec.df; |
| out << "\n"; |
| } |
| |
| void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) { |
| auto graph = build_lstm(); |
| CreateAutodiffSubgraphs(graph, /*threshold=*/2); |
| out << "testCreateAutodiffSubgraphs\n"; |
| out << *graph << "\n"; |
| } |
| |
| void testSubgraphUtils() { |
| auto graph = build_lstm(); |
| EliminateCommonSubexpression(graph); |
| |
| std::vector<Node*> originalNodes( |
| graph->nodes().begin(), graph->nodes().end()); |
| |
| // Merge everything into a single subgraph |
| bool first = true; |
  Node* subgraph = nullptr;
| for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) { |
| if (first) { |
| subgraph = SubgraphUtils::createSingletonSubgraph( |
| *it, prim::DifferentiableGraph); |
| it = ++subgraph->reverseIterator(); |
| first = false; |
| } |
| |
| SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph); |
| it = ++subgraph->reverseIterator(); |
| } |
| |
| // Unmerge and compare with original node listing |
| SubgraphUtils::unmergeSubgraph(subgraph); |
| EliminateCommonSubexpression(graph); |
| |
| std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end()); |
| ASSERT_EQ(originalNodes.size(), newNodes.size()); |
| } |
| |
| autograd::Variable var(at::Type& t, at::IntList sizes, bool requires_grad) { |
| return autograd::make_variable(at::rand(sizes, t.options()), requires_grad); |
| } |
| autograd::Variable undef() { |
| return autograd::Variable(); |
| } |
| |
| int device(const autograd::Variable& v) { |
| return v.type().is_cuda() ? v.get_device() : -1; |
| } |
| |
| bool isEqual(at::IntList lhs, at::IntList rhs) { |
| return lhs.size() == rhs.size() && |
| std::equal(lhs.begin(), lhs.end(), rhs.begin()); |
| } |
| |
| bool isEqual(const CompleteArgumentInfo& ti, const autograd::Variable& v) { |
| if (!ti.defined()) |
| return ti.defined() == v.defined(); |
| return ti.device() == device(v) && ti.requires_grad() == v.requires_grad() && |
| ti.type() == v.type().scalarType() && isEqual(ti.sizes(), v.sizes()) && |
| isEqual(ti.strides(), v.strides()); |
| } |
| |
// Work around the fact that variable_tensor_list doesn't duplicate all of
// std::vector's constructors. Most constructors are never used in the
// implementation, just in our tests.
| Stack createStack(std::vector<at::Tensor>&& list) { |
| return Stack( |
| std::make_move_iterator(list.begin()), |
| std::make_move_iterator(list.end())); |
| } |
| |
| void testArgumentSpec() { |
| auto& CF = at::CPU(at::kFloat); |
| auto& CD = at::CPU(at::kDouble); |
| auto& GF = at::CUDA(at::kFloat); |
| auto& GD = at::CUDA(at::kDouble); |
| |
| auto list = createStack({var(CF, {1}, true), |
| var(CD, {1, 2}, false), |
| var(GF, {}, true), |
| var(GD, {4, 5, 6}, false), |
| undef()}); |
| |
| // make sure we have some non-standard strides |
| list[1].toTensor().transpose_(0, 1); |
| |
| // same list but different backing values |
| auto list2 = createStack({var(CF, {1}, true), |
| var(CD, {1, 2}, false), |
| var(GF, {}, true), |
| var(GD, {4, 5, 6}, false), |
| undef()}); |
| list2[1].toTensor().transpose_(0, 1); |
| |
| CompleteArgumentSpec a(true, list); |
| CompleteArgumentSpec b(true, list); |
| ASSERT_EQ(a.hashCode(), b.hashCode()); |
| |
| ASSERT_EQ(a, b); |
| CompleteArgumentSpec d(true, list2); |
| ASSERT_EQ(d, a); |
| ASSERT_EQ(d.hashCode(), a.hashCode()); |
| |
| for (size_t i = 0; i < list.size(); ++i) { |
| ASSERT_TRUE(isEqual(a.at(i), list[i].toTensor())); |
| } |
| CompleteArgumentSpec no_grad(/*with_grad=*/false, list); |
| ASSERT_TRUE(no_grad != a); |
| |
| std::unordered_set<CompleteArgumentSpec> spec; |
| spec.insert(std::move(a)); |
| ASSERT_TRUE(spec.count(b) > 0); |
| ASSERT_EQ(spec.count(no_grad), 0); |
| spec.insert(std::move(no_grad)); |
| ASSERT_EQ(spec.count(CompleteArgumentSpec(true, list)), 1); |
| |
| list2[1].toTensor().transpose_(0, 1); |
| CompleteArgumentSpec c(true, list2); // same as list, except for one stride |
| ASSERT_FALSE(c == a); |
| ASSERT_EQ(spec.count(c), 0); |
| |
| Stack stack = {var(CF, {1, 2}, true), 3, var(CF, {1, 2}, true)}; |
| CompleteArgumentSpec with_const(true, stack); |
| ASSERT_EQ(with_const.at(2).sizes().size(), 2); |
| } |
| |
| void testGraphExecutor() { |
| constexpr int batch_size = 4; |
| constexpr int input_size = 256; |
| |
| int hidden_size = 2 * input_size; |
| |
| auto v = [](at::Tensor t) { return autograd::make_variable(t, false); }; |
| |
| auto input = at::randn({batch_size, input_size}, at::kCUDA); |
| auto hx = at::randn({batch_size, hidden_size}, at::kCUDA); |
| auto cx = at::randn({batch_size, hidden_size}, at::kCUDA); |
| auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA)); |
| auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA)); |
| |
| auto g = build_lstm(); |
| GraphExecutor executor(g); |
| auto stack = createStack({v(input), v(hx), v(cx), v(w_ih), v(w_hh)}); |
| executor.run(stack); |
| ASSERT_EQ(stack.size(), 2); |
| at::Tensor r0, r1; |
| std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh); |
| ASSERT_TRUE(almostEqual(Variable(stack[0].toTensor()).data(), r0)); |
| ASSERT_TRUE(almostEqual(Variable(stack[1].toTensor()).data(), r1)); |
| } |
| |
| void testBlocks(std::ostream& out = std::cout) { |
| Graph g; |
| auto a = Var::asNewInput(g, "a"); |
| auto b = Var::asNewInput(g, "b"); |
| auto c = a + b; |
| auto r = g.appendNode(g.create(prim::If, {Var::asNewInput(g, "c").value()})); |
| auto then_block = r->addBlock(); |
| auto else_block = r->addBlock(); |
| { |
| WithInsertPoint guard(then_block); |
| auto t = c + c; |
| then_block->registerOutput(t.value()); |
| } |
| { |
| WithInsertPoint guard(else_block); |
| auto d = b + c; |
| auto e = d + c; |
| else_block->registerOutput(e.value()); |
| } |
| g.registerOutput((Var(r->output()) + c).value()); |
| g.lint(); |
| out << "testBlocks\n" << g << "\n"; |
| r->eraseBlock(0); |
| out << g << "\n"; |
| g.lint(); |
| // test recursive copy of blocks works |
| auto g2 = g.copy(); |
| out << *g2 << "\n"; |
| } |
| |
| const auto cf_examples = R"JIT( |
| def if_test(a, b): |
| # FIXME: use 0 instead of a. |
| # c = 0 |
| c = a |
| if bool(a < b): |
| c = b |
| else: |
| c = a |
| return c |
| def if_one(a, b): |
| c = b |
| if bool(a < b): |
| c = a |
| return c |
| def while_test(a, i): |
| while bool(i < 3): |
| a *= a |
| i += 1 |
| return a |
| )JIT"; |
| void testControlFlow() { |
| auto cu = std::make_shared<script::Module>(); |
| script::defineMethodsInModule( |
| cu, cf_examples, script::nativeResolver, nullptr); |
| auto run = [&](const std::string& name, std::vector<IValue> stack) { |
| auto graph = cu->get_method(name).graph(); |
| Code code(graph); |
| InterpreterState interp(code); |
| interp.run(stack); |
| return stack; |
| }; |
| |
| auto L = [](int64_t l) { |
| return IValue(autograd::make_variable(scalar_to_tensor(at::Scalar(l)))); |
| }; |
| auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); }; |
| auto run_binary = [&](const std::string& name, int64_t a, int64_t b) { |
| return V(run(name, {L(a), L(b)})[0]); |
| }; |
| ASSERT_EQ(2, run_binary("if_test", 1, 2)); |
| ASSERT_EQ(3, run_binary("if_test", 3, 2)); |
| ASSERT_EQ(2, run_binary("if_one", 2, 3)); |
| ASSERT_EQ(2, run_binary("if_one", 3, 2)); |
| ASSERT_EQ(256, run_binary("while_test", 2, 0)); |
| } |
| |
| void testIValue() { |
| Shared<IntList> foo = IntList::create({3, 4, 5}); |
| ASSERT_EQ(foo.use_count(), 1); |
| IValue bar{foo}; |
| ASSERT_EQ(foo.use_count(), 2); |
| auto baz = bar; |
| ASSERT_EQ(foo.use_count(), 3); |
| auto foo2 = std::move(bar); |
| ASSERT_EQ(foo.use_count(), 3); |
| ASSERT_TRUE(foo2.isIntList()); |
| ASSERT_TRUE(bar.isNone()); |
| foo2 = IValue(4.0); |
| ASSERT_TRUE(foo2.isDouble()); |
| ASSERT_EQ(foo2.toDouble(), 4.0); |
| ASSERT_EQ(foo.use_count(), 2); |
| ASSERT_TRUE(ArrayRef<int64_t>(baz.toIntList()->elements()).equals({3, 4, 5})); |
| |
| auto move_it = std::move(baz).toIntList(); |
| ASSERT_EQ(foo.use_count(), 2); |
| ASSERT_TRUE(baz.isNone()); |
| IValue i(4); |
| ASSERT_TRUE(i.isInt()); |
| ASSERT_EQ(i.toInt(), 4); |
| IValue dlist(DoubleList::create({3.5})); |
| ASSERT_TRUE(dlist.isDoubleList()); |
| ASSERT_TRUE(ArrayRef<double>(std::move(dlist).toDoubleList()->elements()) |
| .equals({3.5})); |
| ASSERT_TRUE(dlist.isNone()); |
| dlist = IValue(DoubleList::create({3.4})); |
| ASSERT_TRUE(ArrayRef<double>(dlist.toDoubleList()->elements()).equals({3.4})); |
| IValue the_list(Tuple::create({IValue(3.4), IValue(4), IValue(foo)})); |
| ASSERT_EQ(foo.use_count(), 3); |
| ASSERT_TRUE(the_list.isTuple()); |
  auto second = std::move(the_list).toTuple()->elements().at(1);
  ASSERT_EQ(second.toInt(), 4);
| at::Tensor tv = at::rand({3, 4}); |
| IValue ten(tv); |
| ASSERT_EQ(tv.use_count(), 2); |
| auto ten2 = ten; |
| ASSERT_EQ(tv.use_count(), 3); |
| ASSERT_TRUE(ten2.toTensor().equal(ten.toTensor())); |
| std::move(ten2).toTensor(); |
| ASSERT_EQ(tv.use_count(), 2); |
| } |
| |
| void testProto() { |
| ::ONNX_NAMESPACE::ModelProto proto; |
| proto.set_producer_name("foo"); |
| } |
| |
| void testCustomOperators() { |
| { |
| RegisterOperators reg({createOperator( |
| "foo::bar", [](double a, at::Tensor b) { return a + b; })}); |
| auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); |
| ASSERT_EQ(ops.size(), 1); |
| |
| auto& op = ops.front(); |
| ASSERT_EQ(op->schema().name(), "foo::bar"); |
| |
| ASSERT_EQ(op->schema().arguments().size(), 2); |
| ASSERT_EQ(op->schema().arguments()[0].name(), "_0"); |
| ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); |
| ASSERT_EQ(op->schema().arguments()[1].name(), "_1"); |
| ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType); |
| |
| ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType); |
| |
| Stack stack; |
| push(stack, 2.0f, autograd::make_variable(at::ones(5))); |
| op->getOperation()(stack); |
| at::Tensor output; |
| pop(stack, output); |
| |
| ASSERT_TRUE(output.allclose(autograd::make_variable(at::full(5, 3.0f)))); |
| } |
| { |
| RegisterOperators reg({createOperator( |
| "foo::bar_with_schema(float a, Tensor b) -> Tensor", |
| [](double a, at::Tensor b) { return a + b; })}); |
| |
| auto& ops = |
| getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); |
| ASSERT_EQ(ops.size(), 1); |
| |
| auto& op = ops.front(); |
| ASSERT_EQ(op->schema().name(), "foo::bar_with_schema"); |
| |
| ASSERT_EQ(op->schema().arguments().size(), 2); |
| ASSERT_EQ(op->schema().arguments()[0].name(), "a"); |
| ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); |
| ASSERT_EQ(op->schema().arguments()[1].name(), "b"); |
| ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType); |
| |
| ASSERT_EQ(op->schema().returns().size(), 1); |
| ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType); |
| |
| Stack stack; |
| push(stack, 2.0f, autograd::make_variable(at::ones(5))); |
| op->getOperation()(stack); |
| at::Tensor output; |
| pop(stack, output); |
| |
| ASSERT_TRUE(output.allclose(autograd::make_variable(at::full(5, 3.0f)))); |
| } |
| { |
| // Check that lists work well. |
| RegisterOperators reg({createOperator( |
| "foo::lists(int[] ints, float[] floats, Tensor[] tensors) -> float[]", |
| [](const std::vector<int64_t>& ints, |
| const std::vector<double>& floats, |
| std::vector<at::Tensor> tensors) { return floats; })}); |
| |
| auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); |
| ASSERT_EQ(ops.size(), 1); |
| |
| auto& op = ops.front(); |
| ASSERT_EQ(op->schema().name(), "foo::lists"); |
| |
| ASSERT_EQ(op->schema().arguments().size(), 3); |
| ASSERT_EQ(op->schema().arguments()[0].name(), "ints"); |
| ASSERT_TRUE( |
| op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofInts())); |
| ASSERT_EQ(op->schema().arguments()[1].name(), "floats"); |
| ASSERT_TRUE( |
| op->schema().arguments()[1].type()->isSubtypeOf(ListType::ofFloats())); |
| ASSERT_EQ(op->schema().arguments()[2].name(), "tensors"); |
| ASSERT_TRUE( |
| op->schema().arguments()[2].type()->isSubtypeOf(ListType::ofTensors())); |
| |
| ASSERT_EQ(op->schema().returns().size(), 1); |
| ASSERT_TRUE( |
| op->schema().returns()[0].type()->isSubtypeOf(ListType::ofFloats())); |
| |
| Stack stack; |
| push(stack, std::vector<int64_t>{1, 2}); |
| push(stack, std::vector<double>{1.0, 2.0}); |
| push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))}); |
| op->getOperation()(stack); |
| std::vector<double> output; |
| pop(stack, output); |
| |
| ASSERT_EQ(output.size(), 2); |
| ASSERT_EQ(output[0], 1.0); |
| ASSERT_EQ(output[1], 2.0); |
| } |
| { |
| RegisterOperators reg( |
| "foo::lists2(Tensor[] tensors) -> Tensor[]", |
| [](std::vector<at::Tensor> tensors) { return tensors; }); |
| |
| auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); |
| ASSERT_EQ(ops.size(), 1); |
| |
| auto& op = ops.front(); |
| ASSERT_EQ(op->schema().name(), "foo::lists2"); |
| |
| ASSERT_EQ(op->schema().arguments().size(), 1); |
| ASSERT_EQ(op->schema().arguments()[0].name(), "tensors"); |
| ASSERT_TRUE( |
| op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofTensors())); |
| |
| ASSERT_EQ(op->schema().returns().size(), 1); |
| ASSERT_TRUE( |
| op->schema().returns()[0].type()->isSubtypeOf(ListType::ofTensors())); |
| |
| Stack stack; |
| push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))}); |
| op->getOperation()(stack); |
| std::vector<at::Tensor> output; |
| pop(stack, output); |
| |
| ASSERT_EQ(output.size(), 1); |
| ASSERT_TRUE(output[0].allclose(autograd::make_variable(at::ones(5)))); |
| } |
| { |
| auto op = createOperator( |
| "traced::op(float a, Tensor b) -> Tensor", |
| [](double a, at::Tensor b) { return a + b; }); |
| |
| std::shared_ptr<tracer::TracingState> state; |
| std::tie(state, std::ignore) = tracer::enter({}); |
| |
| Stack stack; |
| push(stack, 2.0f, autograd::make_variable(at::ones(5))); |
| op.getOperation()(stack); |
| at::Tensor output = autograd::make_variable(at::empty({})); |
| pop(stack, output); |
| |
| tracer::exit({IValue(output)}); |
| |
| std::string op_name("traced::op"); |
| bool contains_traced_op = false; |
| for (const auto& node : state->graph->nodes()) { |
| if (std::string(node->kind().toQualString()) == op_name) { |
| contains_traced_op = true; |
| break; |
| } |
| } |
| ASSERT_TRUE(contains_traced_op); |
| } |
| { |
| ASSERT_THROWS_WITH( |
| createOperator( |
| "foo::bar_with_bad_schema(Tensor a) -> Tensor", |
| [](double a, at::Tensor b) { return a + b; }), |
| "Inferred 2 argument(s) for operator implementation, " |
| "but the provided schema specified 1 argument(s)."); |
| ASSERT_THROWS_WITH( |
| createOperator( |
| "foo::bar_with_bad_schema(Tensor a) -> Tensor", |
| [](double a) { return a; }), |
| "Inferred type for argument #0 was float, " |
| "but the provided schema specified type Dynamic " |
| "for the argument in that position"); |
| ASSERT_THROWS_WITH( |
| createOperator( |
| "foo::bar_with_bad_schema(float a) -> (float, float)", |
| [](double a) { return a; }), |
| "Inferred 1 return value(s) for operator implementation, " |
| "but the provided schema specified 2 return value(s)."); |
| ASSERT_THROWS_WITH( |
| createOperator( |
| "foo::bar_with_bad_schema(float a) -> Tensor", |
| [](double a) { return a; }), |
| "Inferred type for return value #0 was float, " |
| "but the provided schema specified type Dynamic " |
| "for the return value in that position"); |
| } |
| { |
| // vector<double> is not supported yet. |
| auto op = createOperator( |
| "traced::op(float[] f) -> int", |
| [](const std::vector<double>& f) -> int64_t { return f.size(); }); |
| |
| std::shared_ptr<tracer::TracingState> state; |
| std::tie(state, std::ignore) = tracer::enter({}); |
| |
| Stack stack; |
| push(stack, std::vector<double>{1.0}); |
| |
| ASSERT_THROWS_WITH( |
| op.getOperation()(stack), |
| "Tracing float lists currently not supported!"); |
| |
| tracer::abandon(); |
| } |
| } |
| |
| // test a few features that are not directly used in schemas yet |
| void testSchemaParser() { |
| // nested arrays |
| auto s = parseSchema("at::what(int[][4] foo) -> ()"); |
| ASSERT_TRUE(s.arguments().at(0).N() == 4); |
| ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments().at(0) |
| .type()->expect<ListType>() |
| ->getElementType() |
| ->expect<ListType>() |
| ->getElementType())); |
| auto s2 = parseSchema("at::what(int[][] foo) -> ()"); |
| ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments().at(0) |
| .type()->expect<ListType>() |
| ->getElementType() |
| ->expect<ListType>() |
| ->getElementType())); |
| |
| // named returns |
| parseSchema("at::what(Tensor! i_will_be_written_to) -> ()"); |
| auto s3 = parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)"); |
| ASSERT_TRUE(s3.returns().at(0).name() == "the_return"); |
| ASSERT_TRUE(s3.returns().at(1).name() == "the_return2"); |
| |
| // futures |
| auto s4 = parseSchema("at::what(Future(int) foo) -> ()"); |
| ASSERT_TRUE(IntType::get()->isSubtypeOf(s4.arguments().at(0) |
| .type()->expect<FutureType>() |
| ->getElementType())); |
| |
| // test tensor with annotated alias sets |
| parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))"); |
| |
| { |
| const auto s = parseSchema( |
| "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)" |
| " -> (Tensor(b|c)[](a!))"); |
| |
| // The list itself is annotated with `a` |
| const auto& aliasInfo = *s.arguments().at(0).alias_info(); |
| ASSERT_TRUE( |
| aliasInfo.sets() == |
| std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")}); |
| ASSERT_TRUE(aliasInfo.isWrite()); |
| |
| // Check the contained types |
| ASSERT_TRUE(!aliasInfo.containedTypes().empty()); |
| const auto& containedAliasInfo = aliasInfo.containedTypes()[0]; |
| const auto expected = std::unordered_set<Symbol>{ |
| Symbol::fromQualString("alias::b"), |
| Symbol::fromQualString("alias::c"), |
| }; |
| ASSERT_TRUE(containedAliasInfo.sets() == expected); |
| ASSERT_FALSE(containedAliasInfo.isWrite()); |
| } |
| } |
| |
| void testTopologicalIndex() { |
| { |
| Graph graph; |
| auto node1 = graph.create(prim::Undefined); |
| auto node2 = graph.create(prim::Undefined); |
| auto node3 = graph.create(prim::Undefined); |
| auto node4 = graph.create(prim::Undefined); |
| |
| graph.appendNode(node4); |
| graph.prependNode(node1); |
| node2->insertAfter(node1); |
| node3->insertBefore(node4); |
| |
| // nodes should be in numerical order |
| ASSERT_TRUE(node1->isBefore(node2)); |
| ASSERT_TRUE(node1->isBefore(node3)); |
| ASSERT_TRUE(node1->isBefore(node4)); |
| ASSERT_TRUE(node2->isAfter(node1)); |
| ASSERT_TRUE(node2->isBefore(node3)); |
| ASSERT_TRUE(node2->isBefore(node4)); |
| ASSERT_FALSE(node3->isBefore(node1)); |
| ASSERT_FALSE(node3->isBefore(node2)); |
| ASSERT_FALSE(node3->isAfter(node4)); |
| |
    // Build up a block structure
| // node3 |
| // /\ ... |
| // A B block1 |
| // \ ... |
| // C block2 |
| auto block1 = node3->addBlock(); |
| auto A = graph.create(prim::Undefined); |
| block1->appendNode(A); |
| auto B = graph.create(prim::Undefined); |
| block1->appendNode(B); |
| auto block2 = B->addBlock(); |
| auto C = graph.create(prim::Undefined); |
| block2->appendNode(C); |
| |
| // Check isAfter on different block levels |
| ASSERT_TRUE(node1->isBefore(A)); |
| ASSERT_TRUE(A->isBefore(B)); |
| ASSERT_TRUE(A->isBefore(C)); |
| |
| // make sure things don't blow up on deletions |
| node2->destroy(); |
| auto node2p = graph.create(prim::Undefined); |
| node2p->insertAfter(node1); |
| ASSERT_TRUE(node1->isBefore(node2p)); |
| ASSERT_TRUE(node2p->isBefore(node3)); |
| } |
| { |
| // Induce reindexing to test that path |
| Graph graph; |
| std::map<size_t, Node*> nodes; |
| |
| auto anchor = graph.create(prim::Undefined); |
| graph.appendNode(anchor); |
| // Inserting to the same place a lot will trigger reindexing |
| for (auto i = 0; i < 100; ++i) { |
| auto n = graph.create(prim::Undefined); |
| n->insertAfter(anchor); |
| nodes[i] = n; |
| } |
| |
| // Nodes should be in reverse order |
| for (auto i = 0; i < 100; ++i) { |
| for (auto j = i + 1; j < 100; ++j) { |
| ASSERT_TRUE(nodes[i]->isAfter(nodes[j])); |
| } |
| } |
| } |
| } |
| |
| |
| std::unique_ptr<detail::DynamicDAG<std::string>> newDynamicDAG() { |
  return std::unique_ptr<detail::DynamicDAG<std::string>>(
      new detail::DynamicDAG<std::string>());
| } |
| |
| void testNewVertex() { |
| auto graph = newDynamicDAG(); |
| JIT_ASSERT(graph->debugNumVertices() == 0); |
| auto a = graph->newVertex("a"); |
| JIT_ASSERT(graph->debugNumVertices() == 1); |
| JIT_ASSERT(a->ord == 0); |
| JIT_ASSERT(a->data.size() == 1); |
| JIT_ASSERT(a->data[0] == "a"); |
| JIT_ASSERT(a->in_edges().size() == 0); |
| JIT_ASSERT(a->out_edges().size() == 0); |
| auto b = graph->newVertex("b"); |
| auto c = graph->newVertex("c"); |
| JIT_ASSERT(graph->debugNumVertices() == 3); |
| JIT_ASSERT(b->ord == 1); |
| JIT_ASSERT(c->ord == 2); |
| } |
| |
| void testAddEdgeBasic() { |
| // a -> b -> c |
| // \---------^ |
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| auto c = graph->newVertex("c"); |
| graph->addEdge(a, b); |
| graph->addEdge(b, c); |
| graph->addEdge(a, c); |
| JIT_ASSERT(a->in_edges().size() == 0); |
| JIT_ASSERT(a->out_edges().size() == 2); |
| JIT_ASSERT(a->out_edges().contains(b)); |
| JIT_ASSERT(a->out_edges().contains(c)); |
| JIT_ASSERT(b->in_edges().size() == 1); |
| JIT_ASSERT(b->out_edges().size() == 1); |
| JIT_ASSERT(b->in_edges().contains(a)); |
| JIT_ASSERT(b->out_edges().contains(c)); |
| JIT_ASSERT(c->in_edges().size() == 2); |
| JIT_ASSERT(c->out_edges().size() == 0); |
| JIT_ASSERT(c->in_edges().contains(a)); |
| JIT_ASSERT(c->in_edges().contains(b)); |
| } |
| |
| void testAddEdgeCycleDetection() { |
| // a -> b -> c |
| // ^---------/ |
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| auto c = graph->newVertex("c"); |
| graph->addEdge(a, b); |
| graph->addEdge(b, c); |
| bool erred = false; |
| try { |
| graph->addEdge(c, a); |
  } catch (const c10::Error&) {
| erred = true; |
| } |
| JIT_ASSERT(erred); |
| } |
| |
| void testAddEdgeReordersBasic() { |
| // a, b => b -> a |
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| JIT_ASSERT(a->ord == 0); |
| JIT_ASSERT(b->ord == 1); |
| graph->addEdge(b, a); |
| JIT_ASSERT(a->ord == 1); |
| JIT_ASSERT(b->ord == 0); |
| } |
| |
| void testAddEdgeReordersComplicated() { |
| // a -> b c -> d with addEdge(d, b) ==> |
| // c -> d -> a -> b |
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| auto c = graph->newVertex("c"); |
| auto d = graph->newVertex("d"); |
| graph->addEdge(a, b); |
| graph->addEdge(c, d); |
| JIT_ASSERT(a->ord == 0); |
| JIT_ASSERT(b->ord == 1); |
| JIT_ASSERT(c->ord == 2); |
| JIT_ASSERT(d->ord == 3); |
| graph->addEdge(d, a); |
| JIT_ASSERT(c->ord == 0); |
| JIT_ASSERT(d->ord == 1); |
| JIT_ASSERT(a->ord == 2); |
| JIT_ASSERT(b->ord == 3); |
| JIT_ASSERT(c->in_edges().size() == 0); |
| JIT_ASSERT(c->out_edges().size() == 1); |
| JIT_ASSERT(c->out_edges().contains(d)); |
| JIT_ASSERT(d->in_edges().size() == 1); |
| JIT_ASSERT(d->out_edges().size() == 1); |
| JIT_ASSERT(d->in_edges().contains(c)); |
| JIT_ASSERT(d->out_edges().contains(a)); |
| JIT_ASSERT(a->in_edges().size() == 1); |
| JIT_ASSERT(a->out_edges().size() == 1); |
| JIT_ASSERT(a->in_edges().contains(d)); |
| JIT_ASSERT(a->out_edges().contains(b)); |
| JIT_ASSERT(b->in_edges().size() == 1); |
| JIT_ASSERT(b->out_edges().size() == 0); |
| JIT_ASSERT(b->in_edges().contains(a)); |
| } |
| |
| void testRemoveEdgeBasic() { |
| // a -> b |
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| graph->addEdge(a, b); |
| JIT_ASSERT(graph->debugNumVertices() == 2); |
| graph->removeEdge(a, b); |
| JIT_ASSERT(graph->debugNumVertices() == 2); |
| JIT_ASSERT(a->out_edges().size() == 0); |
| JIT_ASSERT(b->in_edges().size() == 0); |
| } |
| |
| void testRemoveVertexBasic() { |
  // a -> b -> c
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| auto c = graph->newVertex("c"); |
| graph->addEdge(a, b); |
| graph->addEdge(b, c); |
| JIT_ASSERT(graph->debugNumVertices() == 3); |
| graph->removeVertex(b); |
| JIT_ASSERT(graph->debugNumVertices() == 2); |
| JIT_ASSERT(a->out_edges().size() == 0); |
| JIT_ASSERT(c->in_edges().size() == 0); |
| } |
| |
| void testContractEdgeBasic() { |
| // a -> b -> c -> d |
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| auto c = graph->newVertex("c"); |
| auto d = graph->newVertex("d"); |
| graph->addEdge(a, b); |
| graph->addEdge(b, c); |
| graph->addEdge(c, d); |
| graph->contractEdge(b, c); |
| JIT_ASSERT(graph->debugNumVertices() == 3); |
| JIT_ASSERT(a->out_edges().size() == 1); |
| JIT_ASSERT(d->in_edges().size() == 1); |
| JIT_ASSERT(*a->out_edges().begin() == *d->in_edges().begin()); |
| auto* contracted = *a->out_edges().begin(); |
| JIT_ASSERT(contracted->data.size() == 2); |
| JIT_ASSERT(contracted->data[0] == "b"); |
| JIT_ASSERT(contracted->data[1] == "c"); |
| JIT_ASSERT(contracted->out_edges().size() == 1); |
| JIT_ASSERT(contracted->in_edges().size() == 1); |
| JIT_ASSERT(contracted->in_edges().contains(a)); |
| JIT_ASSERT(contracted->out_edges().contains(d)); |
| } |
| |
| void testContractEdgeCycleDetection() { |
| // a -> b -> c |
| // `---------^ |
| // contractEdge(a, c) will cause a cycle |
| auto graph = newDynamicDAG(); |
| auto a = graph->newVertex("a"); |
| auto b = graph->newVertex("b"); |
| auto c = graph->newVertex("c"); |
| graph->addEdge(a, b); |
| graph->addEdge(b, c); |
| graph->addEdge(a, c); |
| JIT_ASSERT(!graph->contractEdge(a, c)); |
| } |
| |
| void testDynamicDAG() { |
| testNewVertex(); |
| testAddEdgeBasic(); |
| testAddEdgeCycleDetection(); |
| testAddEdgeReordersBasic(); |
| testAddEdgeReordersComplicated(); |
| testRemoveEdgeBasic(); |
| testRemoveVertexBasic(); |
| testContractEdgeBasic(); |
| testContractEdgeCycleDetection(); |
| } |
| |
| // Fixture to set up a graph and make assertions clearer |
| struct TopoMoveTestFixture { |
| TopoMoveTestFixture() { |
| createGraph(); |
| aliasDb = AliasAnalysis(graph); |
| } |
| |
| // Nodes are named after their output. |
| // e.g. "a" is an alias for "the node that outputs the value `a`" |
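| //
| // For reference, the dependency structure built below (edges point from a
| // value to its users):
| //   a -> {b, d, l}    b -> {d, e}    c -> {e}    e -> {f, g}
| //   g -> {h, i}       i -> {j, k}   l -> m (via m's block)   m -> {n}
| //   n -> {o}          p is isolated; q -> {r, s}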
| void createGraph() { |
| graph = std::make_shared<Graph>(); |
| createNode("a", {}); |
| createNode("b", {"a"}); |
| createNode("c", {}); |
| createNode("d", {"a", "b"}); |
| createNode("e", {"c", "b"}); |
| createNode("f", {"e"}); |
| createNode("g", {"e"}); |
| createNode("h", {"g"}); |
| createNode("i", {"g"}); |
| createNode("j", {"i"}); |
| createNode("k", {"i"}); |
| createNode("l", {"a"}); |
| createNode("m", {}, {"l"}); // block depends on l |
| createNode("n", {"m"}); |
| createNode("o", {"n"}); |
| createNode("p", {}); |
| createNode("q", {}); |
| createNode("r", {"q"}); |
| createNode("s", {"q"}); |
| |
| graph->lint(); |
| } |
| |
| void createNode( |
| const std::string& name, |
| const std::vector<std::string>& inputNames, |
| const std::vector<std::string>& blockInputNames = {}) { |
| std::vector<Value*> inputs; |
| for (const auto& name : inputNames) {
| inputs.push_back(nodes.at(name)->output()); |
| } |
| auto node = graph->appendNode(graph->create(prim::Undefined, inputs)); |
| node->output()->setUniqueName(name); |
| nodes[name] = node; |
| |
| if (!blockInputNames.empty()) {
| node->addBlock(); |
| std::vector<Value*> blockDeps; |
| for (const auto& name : blockInputNames) {
| blockDeps.push_back(nodes.at(name)->output()); |
| } |
| |
| auto block = node->blocks().at(0); |
| block->appendNode(graph->create(prim::Undefined, blockDeps)); |
| } |
| } |
| |
| bool moveBeforeTopologicallyValid( |
| const std::string& toInsert, |
| const std::string& insertPoint) { |
| std::function<bool(Node*, Node*)> func = [this](Node* toInsert, |
| Node* insertPoint) { |
| return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb); |
| }; |
| return moveWithChecks(toInsert, insertPoint, func); |
| } |
| |
| bool moveAfterTopologicallyValid( |
| const std::string& toInsert, |
| const std::string& insertPoint) { |
| std::function<bool(Node*, Node*)> func = [this](Node* toInsert, |
| Node* insertPoint) { |
| return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb); |
| }; |
| return moveWithChecks(toInsert, insertPoint, func); |
| } |
| |
| bool moveWithChecks( |
| const std::string& toInsert, |
| const std::string& insertPoint, |
| std::function<bool(Node*, Node*)> func) { |
| auto n = nodes.at(toInsert); |
| auto insert = nodes.at(insertPoint); |
| bool isAfter = n->isAfter(insert); |
| |
| std::vector<Node*> originalOrdering; |
| Node* original = isAfter ? n->next() : n->prev(); |
| |
| auto curNode = original; |
| while (curNode != n->owningBlock()->return_node()) { |
| originalOrdering.push_back(curNode); |
| if (isAfter) { |
| curNode = curNode->next(); |
| } else { |
| curNode = curNode->prev(); |
| } |
| } |
| |
| const auto couldMove = func(n, insert); |
| // Check that the graph is still well-formed
| graph->lint(); |
| |
| // Given this picture of the nodes:
| // <some nodes> ... toInsert ... <some more nodes> ... insertPoint |
| // ^----------^ check that these nodes haven't moved |
| curNode = original; |
| size_t idx = 0; |
| while (curNode != n->owningBlock()->return_node()) { |
| JIT_ASSERT(originalOrdering[idx] == curNode); |
| if (isAfter) { |
| curNode = curNode->next(); |
| } else { |
| curNode = curNode->prev(); |
| } |
| idx++; |
| } |
| |
| return couldMove; |
| } |
| |
| void checkPostCondition( |
| const std::string& toInsert, |
| const std::string& insertPoint, |
| bool after) { |
| if (after) { |
| JIT_ASSERT(nodes.at(toInsert)->prev() == nodes.at(insertPoint)); |
| } else { |
| JIT_ASSERT(nodes.at(toInsert)->next() == nodes.at(insertPoint)); |
| } |
| } |
| |
| std::shared_ptr<Graph> graph; |
| c10::optional<AliasDb> aliasDb; |
| std::unordered_map<std::string, Node*> nodes; |
| }; |
| |
| void testTopologicalMove() { |
| { |
| // Check that we remove `this`'s dependencies properly when we need to
| // split `this` from its dependencies (see the move implementation for
| // details)
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveBeforeTopologicallyValid("q", "s")); |
| fixture.checkPostCondition("q", "s", false); |
| } |
| // Move after |
| { |
| // Simple move backward |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveAfterTopologicallyValid("c", "a")); |
| fixture.checkPostCondition("c", "a", true); |
| } |
| { |
| // Simple invalid move backward
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(!fixture.moveAfterTopologicallyValid("d", "a")); |
| } |
| { |
| // Doesn't actually move anything
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveAfterTopologicallyValid("f", "e")); |
| fixture.checkPostCondition("f", "e", true); |
| } |
| { |
| // Move backward with multiple dependencies
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveAfterTopologicallyValid("e", "c")); |
| fixture.checkPostCondition("e", "c", true); |
| } |
| { |
| // Move backward with non-zero working set |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveAfterTopologicallyValid("k", "f")); |
| fixture.checkPostCondition("k", "f", true); |
| } |
| { |
| // Simple move forward |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveAfterTopologicallyValid("c", "d")); |
| fixture.checkPostCondition("c", "d", true); |
| } |
| { |
| // Move forward with non-zero working set |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveAfterTopologicallyValid("f", "l")); |
| fixture.checkPostCondition("f", "l", true); |
| } |
| |
| // Move before |
| { |
| // Simple move forward |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveBeforeTopologicallyValid("b", "d")); |
| fixture.checkPostCondition("b", "d", false); |
| } |
| { |
| // Simple move backward |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveBeforeTopologicallyValid("c", "a")); |
| fixture.checkPostCondition("c", "a", false); |
| } |
| { |
| // Doesn't actually move anything
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveBeforeTopologicallyValid("a", "b")); |
| fixture.checkPostCondition("a", "b", false); |
| } |
| { |
| // Move forward with dependencies
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveBeforeTopologicallyValid("f", "m")); |
| fixture.checkPostCondition("f", "m", false); |
| } |
| { |
| // Move backward with dependencies
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(fixture.moveBeforeTopologicallyValid("l", "f")); |
| fixture.checkPostCondition("l", "f", false); |
| } |
| |
| // Check that dependencies in blocks are recognized
| { |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(!fixture.moveAfterTopologicallyValid("l", "m")); |
| JIT_ASSERT(!fixture.moveBeforeTopologicallyValid("m", "l")); |
| JIT_ASSERT(!fixture.moveAfterTopologicallyValid("n", "l")); |
| JIT_ASSERT(!fixture.moveBeforeTopologicallyValid("l", "n")); |
| } |
| |
| // Test that moveAfter(n) and moveBefore(n->next()) are not necessarily
| // equivalent. Here `o` depends on `n`, so we can't move `n` after `o`;
| // but `p` does not depend on `o`, so we can still move `o` before `p`.
| { |
| TopoMoveTestFixture fixture; |
| JIT_ASSERT(!fixture.moveAfterTopologicallyValid("n", "o")); |
| JIT_ASSERT(fixture.moveBeforeTopologicallyValid("o", "p")); |
| fixture.checkPostCondition("o", "p", false); |
| } |
| } |
| |
| void testAliasAnalysis() { |
| { |
| auto graph = std::make_shared<Graph>(); |
| auto a = graph->addInput(); |
| auto b = graph->addInput(); |
| |
| // addsB = b + b |
| // c = a + b |
| // a += b |
| // d = c + c |
| auto addsB = graph->insert(aten::add, {b, b}); |
| auto c = graph->insert(aten::add, {a, b}); |
| auto aMut = graph->insert(aten::add_, {a, b}); |
| auto d = graph->insert(aten::add, {c, c}); |
| |
| graph->lint(); |
| |
| const auto aliasDb = AliasAnalysis(graph); |
| // Can't move past a mutation of a used value |
| JIT_ASSERT(!c->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb)); |
| JIT_ASSERT(d->node()->moveAfterTopologicallyValid(c->node(), aliasDb)); |
| |
| // b may alias a (since they are both graph inputs)
| JIT_ASSERT( |
| !addsB->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb)); |
| JIT_ASSERT(addsB->node()->moveAfterTopologicallyValid(c->node(), aliasDb)); |
| |
| graph->lint(); |
| } |
| { |
| auto graph = std::make_shared<Graph>(); |
| auto a = graph->addInput(); |
| auto b = graph->addInput(); |
| |
| auto constant = graph->insertConstant(1); |
| auto fresh = graph->insert(aten::rand, {constant}); |
| auto usesB = graph->insert(aten::add, {b, fresh}); |
| auto aliasesB = graph->insert(aten::select, {a, constant, constant}); |
| auto mutatesAliasOfB = graph->insert(aten::add_, {aliasesB, fresh}); |
| auto c = graph->insert(aten::add, {fresh, aliasesB}); |
| graph->lint(); |
| |
| const auto aliasDb = AliasAnalysis(graph); |
| |
| JIT_ASSERT(!aliasesB->node()->moveAfterTopologicallyValid( |
| mutatesAliasOfB->node(), aliasDb)); |
| JIT_ASSERT(!usesB->node()->moveAfterTopologicallyValid( |
| mutatesAliasOfB->node(), aliasDb)); |
| } |
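| {
| // Not in the original suite: a hedged sketch, assuming the analysis knows
| // that aten::rand produces a fresh tensor which cannot alias a graph
| // input. Under that assumption, a read of `a` may be hoisted past an
| // in-place write to the fresh value.
| auto graph = std::make_shared<Graph>();
| auto a = graph->addInput();
| auto constant = graph->insertConstant(1);
| auto fresh = graph->insert(aten::rand, {constant});
| auto usesA = graph->insert(aten::add, {a, a});
| auto mutatesFresh = graph->insert(aten::add_, {fresh, fresh});
| graph->lint();
| 
| const auto aliasDb = AliasAnalysis(graph);
| // `usesA` only reads `a`, which should not alias `fresh`, so moving it
| // past the in-place op should be allowed.
| JIT_ASSERT(usesA->node()->moveAfterTopologicallyValid(
| mutatesFresh->node(), aliasDb));
| }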
| } |
| } // namespace |
| } // namespace jit |
| } // namespace torch |