commit | 0e582fbfcc8cab66c0265d3fe326e3dc505855d1 | [log] [tgz] |
---|---|---|
author | jjsjann123 <jiej@nvidia.com> | Wed Sep 21 15:03:10 2022 -0700 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Fri Sep 23 20:29:48 2022 +0000 |
tree | ae012c71946ceeada7ec02ee4ecbfd994389adfa | |
parent | 52a8be523ce682ce26dd793a4154b668b1f37703 [diff] |
[NVFuser] Upstream push 0907 (#84626) Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ Codegen changes include: - codegen improvement: i. improved view support on pointwise and transpose scheduler ii. grouped grid welford added for better outer-norm grid persistence in normalization - misc: i. new composite ops added: variance_mean , arange, ii. fixes misaligned address for transpose scheduler iii. refactor on separation of compilation API from execution API to prepare us for async compilation iv. double type support on expression evaluator v. PYTORCH_NVFUSER_DUMP refactor to save PTX and CUBIN Commits that's in this PR from the devel branch: ``` 89330aa23aa804340b2406ab58899d816e3dc3d2 Tensor factories must set the output shape as its input (#1939) b2fd01ea9346712c6d6f623ca6addbc4888d008e arange support (#1933) 56c00fd3922dad7dfc57351ad7d780f0f2f8e4ed Double support on all expression evaluators (#1937) 371f28223e57fe3f6b5e50a0a45177e6a5c0785c Improve trivial reduction merge support (#1931) 1d0c26790e5647920b40d419d26815bbe310b3a6 Test `rand` in a fusion with zero tensor input (#1932) 0dab160fb2177d178eef3148c6a529e0855009e9 Fix softmax bwd sizes. (#1890) ef98f360f6d3e3e1cc662ecb65202d88150f128d Fix a bug (#1936) 63132a0c56508c550084b07fb76a3df865102d00 Propagate permissive mapping information into indexing pass (#1929) b4ac2c88d78078ee4d8b21c4fc51645b5710a282 Map IterationDomains through view operations. (#1919) c0a187a7619d7cf9dc920294e15461791e8d6d4d do not use deprecated functions (#1935) 88de85e758c5e4afb7b6e746573c0d9a53b4cea7 Upstream cherry pick fixes 0811 (#1934) b247dcf7c57dc6ac3f7a799b0a6beb7770536a74 Separate kernel compilation API from kernel execution API (#1914) b34e3b93ee1a8030730c14af3995dd95665af07d Fix `ir_utils::hasBlockSync` + misc fixes in transpose scheduler (#1924) 14a53e6707f43bf760494c238a46386d69830822 Nullary RNGOp (#1892) 3c3c89e638f5172cafb0761f22bacd1fd695eec3 Misc fixes/tuning for transpose scheduler (#1912) 20cf109c8b44d48f61977e35bae94368985144ac Grouped grid welford (#1921) 6cf7eb024c9e53c358cbe56597e117bad56efefd Transpose scheduler small dim sizes better support (#1910) 9341ea9a5bf42f9b14ccad0c94edbc79fc5bb552 Disabled ViewPersistentShmoo sizes that results in NAN (#1922) 057237f66deeea816bb943d802a97c1b7e4414ab Fix CUDA driver error: misaligned address for transpose scheduler (#1918) 3fb3d80339e4f794767a53eb8fdd61e64cf404a2 Add variance_mean function using Welford (#1907) 98febf6aa3b8c6fe4fdfb2864cda9e5d30089262 Remove DisableOption::UnrollWithRng (#1913) ee8ef33a5591b534cf587d347af11e48ba7a15d4 Minor fix for the debug interface of using PTX directly (#1917) 6e8f953351f9dabfd1f991d8431cecb6c2ce684d Add PYTORCH_NVFUSER_DUMP options to save PTX and CUBIN (#1916) 5eefa9a72385f6a4b145680a9dcc52d7e8293763 dopt is only available since nvrtc 11.7 (#1915) 2ec8fc711eafc72451eebf0f5e2a98a38bf3f6ef Kill computeAtBetween (#1911) d0d106a1d9af118d71673173674e875be35d259d Improve view support on pointwise and transpose scheduler (#1906) e71e1ecefe67219846070590bbed54bbc7416b79 Fix name clash of RNG with shared memory (#1904) 3381793a253689abf224febc73fd3fe2a0dbc921 Fix mutator and sameAs for expanded IterDomain (#1902) ``` RUN_TORCHBENCH: nvfuser Differential Revision: [D39324552](https://our.internmc.facebook.com/intern/diff/D39324552) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84626 Approved by: https://github.com/malfet
PyTorch is a Python package that provides two high-level features:
You can reuse your favorite Python packages such as NumPy, SciPy, and Cython to extend PyTorch when needed.
Our trunk health (Continuous Integration signals) can be found at hud.pytorch.org.
At a granular level, PyTorch is a library that consists of the following components:
Component | Description |
---|---|
torch | A Tensor library like NumPy, with strong GPU support |
torch.autograd | A tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
torch.jit | A compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code |
torch.nn | A neural networks library deeply integrated with autograd designed for maximum flexibility |
torch.multiprocessing | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
torch.utils | DataLoader and other utility functions for convenience |
Usually, PyTorch is used either as:
Elaborating Further:
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the computation by a huge amount.
We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs such as slicing, indexing, mathematical operations, linear algebra, reductions. And they are fast!
PyTorch has a unique way of building neural networks: using and replaying a tape recorder.
Most frameworks such as TensorFlow, Theano, Caffe, and CNTK have a static view of the world. One has to build a neural network and reuse the same structure again and again. Changing the way the network behaves means that one has to start from scratch.
With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to change the way your network behaves arbitrarily with zero lag or overhead. Our inspiration comes from several research papers on this topic, as well as current and past work such as torch-autograd, autograd, Chainer, etc.
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research.
PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use NumPy / SciPy / scikit-learn etc. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as Cython and Numba. Our goal is to not reinvent the wheel where appropriate.
PyTorch is designed to be intuitive, linear in thought, and easy to use. When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. When you drop into a debugger or receive error messages and stack traces, understanding them is straightforward. The stack trace points to exactly where your code was defined. We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines.
PyTorch has minimal framework overhead. We integrate acceleration libraries such as Intel MKL and NVIDIA (cuDNN, NCCL) to maximize speed. At the core, its CPU and GPU Tensor and neural network backends are mature and have been tested for years.
Hence, PyTorch is quite fast – whether you run small or large neural networks.
The memory usage in PyTorch is extremely efficient compared to Torch or some of the alternatives. We've written custom memory allocators for the GPU to make sure that your deep learning models are maximally memory efficient. This enables you to train bigger deep learning models than before.
Writing new neural network modules, or interfacing with PyTorch's Tensor API was designed to be straightforward and with minimal abstractions.
You can write new neural network layers in Python using the torch API or your favorite NumPy-based libraries such as SciPy.
If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and with minimal boilerplate. No wrapper code needs to be written. You can see a tutorial here and an example here.
Commands to install binaries via Conda or pip wheels are on our website: https://pytorch.org/get-started/locally/
Python wheels for NVIDIA's Jetson Nano, Jetson TX1/TX2, Jetson Xavier NX/AGX, and Jetson AGX Orin are provided here and the L4T container is published here
They require JetPack 4.2 and above, and @dusty-nv and @ptrblck are maintaining them.
If you are installing from source, you will need:
We highly recommend installing an Anaconda environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro.
If you want to compile with CUDA support, install the following (note that CUDA is not supported on macOS)
Note: You could refer to the cuDNN Support Matrix for cuDNN versions with the various supported CUDA, CUDA driver and NVIDIA hardware
If you want to disable CUDA support, export the environment variable USE_CUDA=0
. Other potentially useful environment variables may be found in setup.py
.
If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xavier), Instructions to install PyTorch for Jetson Nano are available here
If you want to compile with ROCm support, install
If you want to disable ROCm support, export the environment variable USE_ROCM=0
. Other potentially useful environment variables may be found in setup.py
.
Common
conda install astunparse numpy ninja pyyaml setuptools cmake cffi typing_extensions future six requests dataclasses
On Linux
conda install mkl mkl-include # CUDA only: Add LAPACK support for the GPU if needed conda install -c pytorch magma-cuda110 # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo
On MacOS
# Add this package on intel x86 processor machines only conda install mkl mkl-include # Add these packages if torch.distributed is needed conda install pkg-config libuv
On Windows
conda install mkl mkl-include # Add these packages if torch.distributed is needed. # Distributed package support on Windows is a prototype feature and is subject to changes. conda install -c conda-forge libuv=1.39
git clone --recursive https://github.com/pytorch/pytorch cd pytorch # if you are updating an existing checkout git submodule sync git submodule update --init --recursive --jobs 0
On Linux
If you're compiling for AMD ROCm then first run this command:
# Only run this if you're compiling for ROCm python tools/amd_build/build_amd.py
Install PyTorch
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} python setup.py install
Note that if you are using Anaconda, you may experience an error caused by the linker:
build/temp.linux-x86_64-3.7/torch/csrc/stub.o: file not recognized: file format not recognized collect2: error: ld returned 1 exit status error: command 'g++' failed with exit status 1
This is caused by ld
from the Conda environment shadowing the system ld
. You should use a newer version of Python that fixes this issue. The recommended Python version is 3.7.6+ and 3.8.1+.
On macOS
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
On Windows
Choose Correct Visual Studio Version.
Sometimes there are regressions in new versions of Visual Studio, so it‘s best to use the same Visual Studio Version 16.8.5 as Pytorch CI’s.
PyTorch CI uses Visual C++ BuildTools, which come with Visual Studio Enterprise, Professional, or Community Editions. You can also install the build tools from https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools do not come with Visual Studio Code by default.
If you want to build legacy python code, please refer to Building on legacy code and CUDA
CPU-only builds
In this mode PyTorch computations will run on your CPU, not your GPU
conda activate python setup.py install
Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking CMAKE_INCLUDE_PATH
and LIB
. The instruction here is an example for setting up both MKL and Intel OpenMP. Without these configurations for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used.
CUDA based build
In this mode PyTorch computations will leverage your GPU via CUDA for faster number crunching
NVTX is needed to build Pytorch with CUDA. NVTX is a part of CUDA distributive, where it is called “Nsight Compute”. To install it onto an already installed CUDA run CUDA installation once again and check the corresponding checkbox. Make sure that CUDA with Nsight Compute is installed after Visual Studio.
Currently, VS 2017 / 2019, and Ninja are supported as the generator of CMake. If ninja.exe
is detected in PATH
, then Ninja will be used as the default generator, otherwise, it will use VS 2017 / 2019.
If Ninja is selected as the generator, the latest MSVC will get selected as the underlying toolchain.
Additional libraries such as Magma, oneDNN, a.k.a MKLDNN or DNNL, and Sccache are often needed. Please refer to the installation-helper to install them.
You can refer to the build_pytorch.bat script for some other environment variables configurations
cmd :: Set the environment variables after you have downloaded and unzipped the mkl package, :: else CMake would throw an error as `Could NOT find OpenMP`. set CMAKE_INCLUDE_PATH={Your directory}\mkl\include set LIB={Your directory}\mkl\lib;%LIB% :: Read the content in the previous section carefully before you proceed. :: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block. :: "Visual Studio 2019 Developer Command Prompt" will be run automatically. :: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator. set CMAKE_GENERATOR_TOOLSET_VERSION=14.27 set DISTUTILS_USE_SDK=1 for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,17^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION% :: [Optional] If you want to override the CUDA host compiler set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe python setup.py install
You can adjust the configuration of cmake variables optionally (without building first), by doing the following. For example, adjusting the pre-detected directories for CuDNN or BLAS can be done with such a step.
On Linux
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} python setup.py build --cmake-only ccmake build # or cmake-gui build
On macOS
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build --cmake-only ccmake build # or cmake-gui build
You can also pull a pre-built docker image from Docker Hub and run with docker v19.03+
docker run --gpus all --rm -ti --ipc=host pytorch/pytorch:latest
Please note that PyTorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g. for multithreaded data loaders) the default shared memory segment size that container runs with is not enough, and you should increase shared memory size either with --ipc=host
or --shm-size
command line options to nvidia-docker run
.
NOTE: Must be built with a docker version > 18.06
The Dockerfile
is supplied to build images with CUDA 11.1 support and cuDNN v8. You can pass PYTHON_VERSION=x.y
make variable to specify which Python version is to be used by Miniconda, or leave it unset to use the default.
make -f docker.Makefile # images are tagged as docker.io/${your_docker_username}/pytorch
To build documentation in various formats, you will need Sphinx and the readthedocs theme.
cd docs/ pip install -r requirements.txt
You can then build the documentation by running make <format>
from the docs/
folder. Run make
to get a list of all available output formats.
If you get a katex error run npm install katex
. If it persists, try npm install -g katex
Note: if you installed
nodejs
with a different package manager (e.g.,conda
) thennpm
will probably install a version ofkatex
that is not compatible with your version ofnodejs
and doc builds will fail. A combination of versions that is known to work isnode@6.13.1
andkatex@0.13.18
. To install the latter withnpm
you can runnpm install -g katex@0.13.18
Installation instructions and binaries for previous PyTorch versions may be found on our website.
Three-pointers to get you started:
PyTorch has a 90-day release cycle (major releases). Please let us know if you encounter a bug by filing an issue.
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.
If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the core in a different direction than you might be aware of.
To learn more about making a contribution to Pytorch, please see our Contribution page.
PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
PyTorch is currently maintained by Adam Paszke, Sam Gross, Soumith Chintala and Gregory Chanan with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito.
Note: This project is unrelated to hughperkins/pytorch with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch.
PyTorch has a BSD-style license, as found in the LICENSE file.