| // *** Tensor Expressions *** |
| // |
| // This tutorial covers basics of NNC's tensor expressions, shows basic APIs to |
| // work with them, and outlines how they are used in the overall TorchScript |
| // compilation pipeline. This doc is permanently a "work in progress" since NNC |
| // is under active development and things change fast. |
| // |
// This tutorial's code is compiled as part of the standard PyTorch build, and
// the executable can be found in `build/bin/tutorial_tensorexpr`.
| // |
| // *** What is NNC *** |
| // |
| // NNC stands for Neural Net Compiler. It is a component of TorchScript JIT |
| // and it performs on-the-fly code generation for kernels, which are often a |
| // combination of multiple aten (torch) operators. |
| // |
// When the JIT interpreter executes a TorchScript model, it automatically
// extracts subgraphs from the TorchScript IR graph for which specialized code
// can be JIT-generated. This usually improves performance, because the
// 'combined' kernel created from the subgraph avoids the memory traffic for
// intermediate results that is unavoidable when the subgraph is interpreted
// as-is, operator by operator. This optimization is often referred to as
// 'fusion'. Relatedly, the process of finding and extracting subgraphs
// suitable for NNC code generation is performed by a JIT pass called the
// 'fuser'.
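//
// For instance (an illustrative sketch, not actual compiler output), a graph
// fragment like
//   %c : Tensor = aten::mul(%a, %b)
//   %d : Tensor = aten::add(%c, %b, %one)
// could be folded by the fuser into a single prim::TensorExprGroup node whose
// subgraph is then compiled by NNC into one kernel.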
| // |
| // *** What is TE *** |
| // |
// TE stands for Tensor Expressions. TE is a commonly used approach for
// compiling kernels that perform tensor (~matrix) computations. The idea
// behind it is that operators are represented as mathematical formulas (the
// tensor expressions) describing the computation they perform. The TE engine
// can then apply mathematical simplifications and other optimizations to
// those formulas and eventually generate executable code that produces the
// same results as the original sequence of operators, but more efficiently.
| // |
// NNC's design and implementation of TE was heavily inspired by the Halide
// and TVM projects.
| #include <iostream> |
| #include <string> |
| |
| #include <c10/util/irange.h> |
| #include <torch/csrc/jit/ir/ir.h> |
| #include <torch/csrc/jit/ir/irparser.h> |
| #include <torch/csrc/jit/tensorexpr/eval.h> |
| #include <torch/csrc/jit/tensorexpr/expr.h> |
| #include <torch/csrc/jit/tensorexpr/ir.h> |
| #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
| #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
| #include <torch/csrc/jit/tensorexpr/kernel.h> |
| #include <torch/csrc/jit/tensorexpr/loopnest.h> |
| #include <torch/csrc/jit/tensorexpr/stmt.h> |
| #include <torch/csrc/jit/tensorexpr/tensor.h> |
| #include <torch/torch.h> |
| |
| using namespace torch::jit::tensorexpr; |
| |
| #ifdef TORCH_ENABLE_LLVM |
| |
| // Helper function to print a snippet from a big multi-line string |
| static void printLinesToFrom(const std::string& input_str, int from, int to); |
| |
| #endif |
| |
| int main(int argc, char* argv[]) { |
| std::cout << "*** Structure of tensor expressions and statements ***" |
| << std::endl; |
| { |
// A tensor expression is a tree of expressions. Each expression has a type,
// and that type defines what sub-expressions the current expression has.
// For instance, an expression of type 'Mul' has the node type 'kMul' and
// two sub-expressions: LHS and RHS. Each of these sub-expressions could in
// turn be a 'Mul' or some other expression.
| // |
| // Let's construct a simple TE: |
| ExprPtr lhs = alloc<IntImm>(5); |
| ExprPtr rhs = alloc<Var>("x", kInt); |
| ExprPtr mul = alloc<Mul>(lhs, rhs); |
| std::cout << "Tensor expression: " << *mul << std::endl; |
| // Prints: Tensor expression: 5 * x |
| |
| // Here we created an expression representing a 5*x computation, where x is |
| // an int variable. |
| |
// Another, probably more convenient, way to construct tensor expressions is
// to use so-called expression handles (as opposed to the raw expressions we
// used in the previous example). Expression handles overload common
// operations and allow us to express the same semantics in a more natural
// way:
| ExprHandle l = 5; |
| ExprHandle r = Var::make("x", kInt); |
| ExprHandle m = l * r; |
| std::cout << "Tensor expression: " << *m.node() << std::endl; |
| // Prints: Tensor expression: 5 * x |
| |
| // Converting from handles to raw expressions and back is easy: |
| ExprHandle handle = Var::make("x", kInt); |
| ExprPtr raw_expr_from_handle = handle.node(); |
| ExprPtr raw_expr = alloc<Var>("x", kInt); |
| ExprHandle handle_from_raw_expr = ExprHandle(raw_expr); |
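// Either form can be printed; for example, the raw expression we just
// obtained is an int variable named 'x':
std::cout << "Raw expression: " << *raw_expr << std::endl;
// Prints: Raw expression: x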
| |
// We can construct arbitrarily complex expressions using mathematical and
// logical operations, casts between various data types, and a number of
// intrinsics.
| ExprHandle a = Var::make("a", kInt); |
| ExprHandle b = Var::make("b", kFloat); |
| ExprHandle c = Var::make("c", kFloat); |
| ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f); |
| std::cout << "Tensor expression: " << *x.node() << std::endl; |
| // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f) |
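// Casts can also be constructed explicitly. A minimal sketch, assuming the
// Cast::make helper declared in ir.h:
ExprHandle a_float = Cast::make(kFloat, a);
std::cout << "Tensor expression: " << *a_float.node() << std::endl;
// Should print something like: Tensor expression: float(a)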
| |
// The ultimate purpose of tensor expressions is to optimize tensor
// computations, and in order to represent accesses to tensor data there is
// a special kind of expression - a load.
// To construct a load we need two pieces: the base and the indices. The
// base of a load is a Buf expression, which can be thought of as a
// placeholder similar to Var, but with dimension info.
| // |
| // Let's construct a simple load: |
| BufHandle A("A", {64, 32}, kInt); |
| VarPtr i_var = alloc<Var>("i", kInt), j_var = alloc<Var>("j", kInt); |
| ExprHandle i(i_var), j(j_var); |
| ExprHandle load = Load::make(A.dtype(), A, {i, j}); |
| std::cout << "Tensor expression: " << *load.node() << std::endl; |
| // Prints: Tensor expression: A[i, j] |
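// The same expression can be built with BufHandle's convenience 'load'
// method, which is the form used later in this tutorial:
ExprHandle load2 = A.load(i, j);
std::cout << "Tensor expression: " << *load2.node() << std::endl;
// Prints: Tensor expression: A[i, j]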
| |
// Tensor Expressions are combined into Tensor Statements, which are used to
// represent the computation of a given operator or a group of operators from
// a fusion group.
| // |
| // There are three main kinds of tensor statements: |
| // - block |
| // - store |
| // - loop |
| // |
| // A Store represents a store to a single element of a tensor (or to a |
| // group of elements if it's a vectorized store). Store statements, |
| // similarly to Load expressions, have a base and indices, but on top of |
| // that they also include a value - an expression representing what needs |
| // to be stored at the given memory location. Let's create a Store stmt: |
| StmtPtr store_a = Store::make(A, {i, j}, i + j); |
| std::cout << "Store statement: " << *store_a << std::endl; |
| // Prints: Store statement: A[i, j] = i + j; |
| |
// An operator fills the entire tensor, not just a single element, and to
// represent this we need to use a For stmt: let's wrap our store stmt in two
// nested loops to express that the variables i and j need to iterate over
// some ranges.
| ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a); |
| ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a); |
| |
| std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl; |
| // Prints: |
| // Nested for loops: |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // A[i, j] = i + j; |
| // } |
| // } |
| |
// A Block statement is used when we need a sequence of other statements.
// E.g. if a fusion group contains several operators, we initially build a
// separate loop nest for each of them and put them all into a common block:
| BufHandle B("B", {64, 32}, kInt); |
| StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j)); |
| ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b); |
| ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b); |
| |
| BlockPtr block = Block::make({loop_i_a, loop_i_b}); |
| std::cout << "Compound Block statement: " << std::endl |
| << *block << std::endl; |
| // Prints: |
| // Compound Block statement: |
| // { |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // A[i, j] = i + j; |
| // } |
| // } |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // B[i, j] = A[i, j]; |
| // } |
| // } |
| // } |
| |
// Manually constructing nested loops and blocks to represent a computation
// might be laborious, so instead we can use the 'Compute' API. This API
// requires us to specify the dimensions and a lambda computing a single
// element of the resulting tensor, and it returns a `Tensor` structure. This
// structure is simply a pair of the buffer that was created to represent the
// result of the computation (BufPtr) and the statement representing the
// computation itself (StmtPtr).
| Tensor C = |
| Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { |
| return i * j; |
| }); |
| std::cout << "Stmt produced by 'Compute' API: " << std::endl |
| << *C.stmt() << std::endl; |
| // Prints: |
| // Stmt produced by 'Compute' API: |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // C[i, j] = i * j; |
| // } |
| // } |
| |
// To construct statements representing computations with reductions, we can
// use the 'Reduce' API - it is similar to 'Compute' but takes a couple of
// extra arguments defining how to perform the reduction. Let's use it to
// define a simple 2D sum of C:
| Tensor D = Reduce( |
| "D", |
| {}, |
| Sum(), |
| [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); }, |
| {64, 32}); |
| std::cout << "Stmt produced by 'Reduce' API: " << std::endl |
| << *D.stmt() << std::endl; |
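// The printed statement (not reproduced here) roughly does two things: it
// initializes D to the sum's identity (0) and then accumulates C[i, j] into
// it inside the two reduction loops.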
| } |
| |
std::cout << "*** Loop nest transformations ***" << std::endl;
| { |
// When a statement for the computation is generated, we might want to
// apply some optimizations to it. Such transformations allow us to end up
// with a statement that produces the same results, but more efficiently.
//
// Let's look at a couple of transformations that are used in NNC. We will
// begin by constructing a Block statement like we did before.
| |
| Tensor C = |
| Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { |
| return i * (j + 1); |
| }); |
| BufHandle c_buf(C.buf()); |
| Tensor D = |
| Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { |
| return c_buf.load(i, j) - i; |
| }); |
| StmtPtr block = Block::make({C.stmt(), D.stmt()}); |
| std::cout << "Stmt produced by 'Compute' API: " << std::endl |
| << *block << std::endl; |
| // Prints: |
| // Stmt produced by 'Compute' API: |
| // { |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // C[i, j] = i * (j + 1); |
| // } |
| // } |
| // for (const auto i_1 : c10::irange(64)) { |
| // for (const auto j_1 : c10::irange(32)) { |
| // D[i_1, j_1] = (C[i_1, j_1]) - i_1; |
| // } |
| // } |
| // } |
| |
// One transformation we can apply to this computation is inlining: taking
// the expression that defines the values of C and replacing the loads from
// C with it.
// To do that, we first need to create a special object called LoopNest -
// all transformations are methods of this class. To create a loop nest we
// need to provide a list of output buffers and the root statement:
| LoopNest nest(block, {D.buf()}); |
| |
| // We can always retrieve the Stmt back from LoopNest: |
| std::cout << "LoopNest root stmt: " << std::endl |
| << *nest.root_stmt() << std::endl; |
| // Prints: |
| // LoopNest root stmt: |
| // { |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // C[i, j] = i * (j + 1); |
| // } |
| // } |
| // for (const auto i_1 : c10::irange(64)) { |
| // for (const auto j_1 : c10::irange(32)) { |
| // D[i_1, j_1] = (C[i_1, j_1]) - i_1; |
| // } |
| // } |
| // } |
| |
| // Now we can apply the inlining transformation: |
| nest.computeInline(C.buf()); |
| std::cout << "Stmt after inlining:" << std::endl |
| << *nest.root_stmt() << std::endl; |
| // Prints: |
| // Stmt after inlining: |
| // { |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // D[i, j] = i * (j + 1) - i; |
| // } |
| // } |
| // } |
| |
| // We can also apply algebraic simplification to a statement: |
| StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt()); |
| std::cout << "Stmt after simplification:" << std::endl |
| << *simplified << std::endl; |
| // Prints: |
| // Stmt after simplification: |
| // { |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // D[i, j] = i * j; |
| // } |
| // } |
| // } |
| |
// Many loop nest transformations are stateless and can be applied without
// creating a LoopNest object. In fact, we plan to make all transformations
// stateless.
// splitWithTail is one such transformation: it splits the iteration space of
// a given loop into two nested loops with a given factor, plus a tail loop
// for the remainder (here 64 = 4 * 13 + 12, so the tail loop below runs 12
// iterations).
| ForPtr outer_loop = to<For>(to<Block>(simplified)->stmts().front()); |
| LoopNest::splitWithTail(outer_loop, 13); |
| // Call simplifier once more to fold some arithmetic. |
| simplified = IRSimplifier::simplify(simplified); |
| std::cout << "Stmt after splitWithTail:" << std::endl |
| << *simplified << std::endl; |
| // Prints: |
| // Stmt after splitWithTail: |
| // { |
| // for (const auto i_outer : c10::irange(4)) { |
| // for (const auto i_inner : c10::irange(13)) { |
| // for (const auto j : c10::irange(32)) { |
| // D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j); |
| // } |
| // } |
| // } |
| // for (const auto i_tail : c10::irange(12)) { |
| // for (const auto j : c10::irange(32)) { |
| // D[i_tail + 52, j] = i_tail * j + 52 * j; |
| // } |
| // } |
| // } |
| |
| // NNC supports a wide range of loop nest transformations, which we are not |
| // listing here. Please refer to documentation in |
| // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h |
| // for more details. |
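// For instance, loopnest.h also declares transformations such as
// splitWithMask, reorderAxis, fuseLoops, and vectorize.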
| } |
| |
| std::cout << "*** Codegen ***" << std::endl; |
| { |
// The ultimate goal of tensor expressions is to provide a mechanism to
// execute a given computation in the fastest possible way. So far we've
// looked at how to describe the computation we're interested in, but we
// haven't looked at how to actually execute it.
//
// Everything we've been dealing with so far was just symbols with no actual
// data associated with them. In this section we will look at how to bridge
// that gap.
| |
| // Let's start by constructing a simple computation for us to work with: |
| BufHandle A("A", {64, 32}, kInt); |
| BufHandle B("B", {64, 32}, kInt); |
| Tensor X = |
| Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { |
| return A.load(i, j) + B.load(i, j); |
| }); |
| |
// And let's lower it to a loop nest, as we did in the previous section. We
// can pass the Tensor object directly:
| LoopNest loopnest({X}); |
| std::cout << *loopnest.root_stmt() << std::endl; |
| // Prints: |
| // { |
| // for (const auto i : c10::irange(64)) { |
| // for (const auto j : c10::irange(32)) { |
| // X[i, j] = (A[i, j]) + (B[i, j]); |
| // } |
// }
// }
| |
// Now imagine that we have two actual 64x32 tensors that we want to sum
// together. How do we pass them to the computation, and how do we carry the
// computation out?
| // |
// The CodeGen object is aimed at providing exactly that functionality.
// CodeGen is an abstract class, and concrete codegens are derived from it.
// Currently, we have three codegens:
// 1) Simple Evaluator,
// 2) LLVM Codegen for CPU,
// 3) CUDA Codegen.
// In this example we will use the Simple Evaluator, since it is available
// everywhere.
| |
// To create a codegen, we need to provide the statement - it specifies the
// computation we want to perform - and a list of placeholders and tensors
// used in the computation. The latter part is crucial since it is the only
// way for the codegen to correlate symbols in the statement with the actual
// data arrays that we will pass in when we actually perform the computation.
//
// Let's create a Simple IR Evaluator codegen for our computation:
| SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X}); |
| |
// We are using the simplest codegen, and in it almost no work is done at
// construction time. Real codegens such as CUDA and LLVM perform compilation
// during that stage, so that when we're about to run the computation
// everything is ready.
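// For example, in an LLVM-enabled build (TORCH_ENABLE_LLVM) an LLVM-based
// codegen could be constructed in the same way (a sketch; it would also
// require including torch/csrc/jit/tensorexpr/llvm_codegen.h):
//   LLVMCodeGen llvm_codegen(loopnest.root_stmt(), {A, B, X});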
| |
| // Let's now create some inputs and run our computation with them: |
| std::vector<int> data_A(64 * 32, 3); // This will be the input A |
| std::vector<int> data_B(64 * 32, 5); // This will be the input B |
| std::vector<int> data_X(64 * 32, 0); // This will be used for the result |
| |
// Now let's invoke our codegen to perform the computation on our data. We
// need to provide as many arguments as the number of placeholders and
// tensors we passed at codegen construction time. The position of each
// argument in this call (these arguments are referred to as 'CallArg's in
// our codebase) determines which symbol - placeholder or tensor (referred to
// as 'BufferArg's) - from the tensor expressions we constructed it
// corresponds to.
// Thus, we will provide three arguments: data_A, data_B, and data_X. data_A
// contains data for the placeholder A, data_B - for the placeholder B, and
// data_X will hold the contents of the tensor X.
| ir_eval(data_A, data_B, data_X); |
| |
| // Let's print one of the elements from each array to verify that the |
| // computation did happen: |
| std::cout << "A[10] = " << data_A[10] << std::endl |
| << "B[10] = " << data_B[10] << std::endl |
| << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl; |
| // Prints: |
| // A[10] = 3 |
| // B[10] = 5 |
| // X[10] = A[10] + B[10] = 8 |
| } |
| |
| std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl; |
| { |
// This section requires an LLVM-enabled PyTorch build, so we have to use a
// guard:
| #ifdef TORCH_ENABLE_LLVM |
| |
// Often we would like to convert TorchScript IR to TE IR rather than
// construct the TE IR from scratch. NNC provides an API to perform such
// lowering: it takes a TorchScript graph and returns an object that can be
// used to invoke the generated kernel.
| // This API is currently used by the TorchScript JIT fuser and can also be |
| // used ahead of time to pre-compile parts of a model. |
| // |
// To get familiar with this API, let's start by defining a simple
// TorchScript graph:
| const auto graph_string = R"IR( |
| graph(%A : Float(5, 3, strides=[3, 1], device=cpu), |
| %B : Float(5, 3, strides=[3, 1], device=cpu)): |
| %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B) |
| %one : int = prim::Constant[value=1]() |
| %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB) |
| %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one) |
| return (%AAB_plus_B))IR"; |
| auto graph = std::make_shared<torch::jit::Graph>(); |
| parseIR(graph_string, &*graph); |
| |
| // This graph defines a simple computation of A*A*B + B where A and B are |
| // input 5x3 tensors. |
| |
// To lower this TorchScript graph to TE, we just need to create a
// TensorExprKernel object. Its constructor builds the corresponding TE IR
// and compiles it for the given backend (in this example, for CPU using the
// LLVM compiler).
| TensorExprKernel kernel(graph); |
| |
| // We can retrieve the generated TE stmt from the kernel object: |
| StmtPtr kernel_stmt = kernel.getCodeGenStmt(); |
| std::cout << "TE Stmt constructed from TorchScript: " << std::endl |
| << *kernel_stmt << std::endl; |
| // Prints: |
| // TE Stmt constructed from TorchScript: |
| // { |
| // for (const auto v : c10::irange(5)) { |
| // for (const auto _tail_tail : c10::irange(3)) { |
| // aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) * |
| // ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) + |
| // (tB[_tail_tail + 3 * v]); |
| // } |
| // } |
| // } |
| |
// We can also examine the generated LLVM IR and assembly code:
| std::cout << "Generated LLVM IR: " << std::endl; |
| auto ir_str = kernel.getCodeText("ir"); |
| printLinesToFrom(ir_str, 15, 20); |
| // Prints: |
| // Generated LLVM IR: |
| // %9 = bitcast float* %2 to <8 x float>* |
| // %10 = load <8 x float>, <8 x float>* %9 ... |
| // %11 = bitcast float* %5 to <8 x float>* |
| // %12 = load <8 x float>, <8 x float>* %11 ... |
| // %13 = fmul <8 x float> %10, %12 |
| // %14 = fmul <8 x float> %10, %13 |
| |
| std::cout << "Generated assembly: " << std::endl; |
| auto asm_str = kernel.getCodeText("asm"); |
| printLinesToFrom(asm_str, 10, 15); |
| // Prints: |
| // Generated assembly: |
| // vmulps %ymm1, %ymm0, %ymm2 |
| // vfmadd213ps %ymm1, %ymm0, %ymm2 |
| // vmovups %ymm2, (%rax) |
| // vmovss 32(%rcx), %xmm0 |
| // vmovss 32(%rdx), %xmm1 |
| // vmulss %xmm1, %xmm0, %xmm2 |
| |
| // We can also execute the generated kernel: |
| auto A = |
| at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * |
| 2.0; |
| auto B = |
| at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * |
| 3.0; |
| std::vector<at::Tensor> inputs = {A, B}; |
| std::vector<torch::IValue> stack = torch::fmap<torch::IValue>(inputs); |
| kernel.run(stack); |
| auto R = stack[0].toTensor(); |
| |
// Let's print one of the elements from the result tensor to verify that the
// computation did happen and was correct (A*A*B + B = 2*2*3 + 3 = 15):
| std::cout << "R[2][2] = " << R[2][2] << std::endl; |
| // Prints: |
| // R[2][2] = 15 |
| // [ CPUFloatType{} ] |
| #endif |
| } |
| return 0; |
| } |
| |
| void printLinesToFrom(const std::string& input_str, int from, int to) { |
| std::istringstream f(input_str); |
| std::string s; |
| int idx = 0; |
| while (getline(f, s)) { |
| if (idx > from) { |
| std::cout << s << "\n"; |
| } |
| if (idx++ > to) { |
| break; |
| } |
| } |
| } |