commit | 50ef099ad0ea307ecab881c8bada6c08f36723a1 | [log] [tgz] |
---|---|---|
author | Alnis Murtovi <murtovi@meta.com> | Mon Jul 15 23:04:06 2024 +0000 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Mon Jul 15 23:04:06 2024 +0000 |
tree | 35a98fc75d8118e6c15271cff4f49d9faa7213a3 | |
parent | 9a5204dc2db965e62ebbb10951054dc25d16a9e9 [diff] |
Learn a heuristic to decide whether to pad before mm (#128643) This PR introduces AutoHeuristic, a framework to collect results from autotuning, learn a heuristic as a machine learning model (a regression tree), and then ship the learned heuristic by generating the regression tree to code. The heuristics have been learned on artificial/random data that has been collected with the `gen_data_pad_mm.py` script. The `gen_pad_mm_a100.sh` scripts can then be used to learn a heuristic and generate it to code. The best model is decided by doing a grid search over various values for `max_depth` and `min_samples_leaf` and choosing the model with the highest number of correct predicitons on the validation set. The heuristic can return "unsure" which means that it is not sure which choice is the best choice and as a result autotuning will happen. On A100 only tensors where each dimension is >= 512 are considered. For smaller tensors the heuristics that I learned returned "unsure" too often. The results for randomly generated data and huggingface look as follows: `max_wrong_speedup` is max(`wrong_speedups`) where `wrong_speedups` contains all the speedups one could have achieved for those examples where the heuristic made a wrong choice, i.e. a `max_wrong_speedup` of 1.37 means that the heuristic selected a choice, but the other choice would have been 1.37x faster. `gman_wrong_speedup` is the geomean of `wrong_speedups`. The heuristic is learned as a regression tree, that returns higher values for better choices. The threshold decides how much better the better choice has to be for it to be returned, i.e. on A100 if the better choice is less than 1.702530x better than the other choice, "unsure" will be returned. This threshold is determined using the validation set. A100 ``` max_depth min_samples_leaf dataset correct wrong unsure total max_wrong_speedup gman_wrong_speedup threshold 15 5.0 10 train 2730 4 3023 5757 1.372220 1.193873 1.702530 16 5.0 10 val 878 0 1042 1920 NaN NaN 1.702530 17 5.0 10 test 925 2 993 1920 1.741708 1.354954 1.702530 18 5.0 10 hf-train 14 0 22 36 NaN NaN 1.702530 19 5.0 10 hf-inf 7 0 1 8 NaN NaN 1.702530 ``` The numbers for huggingface only include tensors where each dim is >=512. If all tensors would have been included there would have been the following number of matmuls, where at least one dimension is unaligned: A100 hf-train: 60 A100 hf-inf: 10 ## Results on running huggingface locally This only includes models where the learned heuristic made at least one decision. For the examples here, it takes around 0.25-0.3 seconds to perform autotuning for the padded and unpadded version, so each decision that the heuristic makes saves around 0.25-0.3 seconds. #pad_mm_autotuning is the number of times autotuning happened in pad_mm and #heuristic_made_decision is the number of times the heuristic made a decision (i.e. it didn't return "unsure"). I ran huggingface locally, each model 5 times and took the median speedup and compilation_latency. Results on huggingface training ``` name speedup_heuristic speedup_baseline speedup_diff compilation_latency_heuristic compilation_latency_baseline compilation_latency_diff comp_latency_reduction% #pad_mm_autotuning #heuristic_made_decision BartForCausalLM 1.19 (+/- 0.00) 1.19 (+/- 0.00) -0.00 40.33 (+/- 1.13) 40.95 (+/- 0.78) -0.62 1.52 3 2 BartForConditionalGeneration 1.53 (+/- 0.06) 1.47 (+/- 0.05) 0.06 81.93 (+/- 5.20) 82.23 (+/- 1.92) -0.30 0.36 3 1 BlenderbotSmallForCausalLM 1.86 (+/- 0.04) 1.86 (+/- 0.00) 0.00 36.76 (+/- 0.49) 37.62 (+/- 1.33) -0.87 2.31 3 2 CamemBert 2.36 (+/- 0.01) 2.35 (+/- 0.01) 0.01 97.60 (+/- 1.91) 98.69 (+/- 1.35) -1.09 1.11 2 1 DistillGPT2 2.57 (+/- 0.01) 2.57 (+/- 0.01) 0.00 57.33 (+/- 0.77) 58.26 (+/- 1.41) -0.93 1.59 3 2 PLBartForCausalLM 2.07 (+/- 0.01) 2.06 (+/- 0.01) 0.01 32.54 (+/- 0.83) 34.65 (+/- 0.71) -2.11 6.10 3 2 PLBartForConditionalGeneration 1.87 (+/- 0.00) 1.88 (+/- 0.00) -0.01 58.45 (+/- 1.24) 58.95 (+/- 1.92) -0.50 0.85 3 1 RobertaForCausalLM 2.39 (+/- 0.01) 2.40 (+/- 0.01) -0.01 97.38 (+/- 1.52) 97.69 (+/- 1.18) -0.31 0.32 2 1 TrOCRForCausalLM 1.70 (+/- 0.00) 1.70 (+/- 0.00) -0.00 44.79 (+/- 1.33) 45.25 (+/- 1.08) -0.46 1.01 3 2 Mean difference in speedup: 0.01 Mean compilation latency saved: -0.80s Mean compilation latency reduction: 1.68% ``` Results on huggingface inference ``` name speedup_heuristic speedup_baseline speedup_diff compilation_latency_heuristic compilation_latency_baseline compilation_latency_diff comp_latency_reduction% #pad_mm_autotuning #heuristic_made_decision BartForCausalLM 1.11 (+/- 0.00) 1.11 (+/- 0.00) 0.00 19.02 (+/- 0.28) 19.40 (+/- 0.35) -0.38 1.95 3 2 BartForConditionalGeneration 1.26 (+/- 0.01) 1.23 (+/- 0.03) 0.03 36.84 (+/- 0.40) 36.55 (+/- 0.75) 0.30 -0.81 3 1 BlenderbotSmallForCausalLM 1.87 (+/- 0.02) 1.87 (+/- 0.01) 0.00 17.53 (+/- 0.31) 18.03 (+/- 0.43) -0.49 2.74 3 2 DistillGPT2 2.50 (+/- 0.02) 2.50 (+/- 0.01) 0.00 16.16 (+/- 0.29) 16.40 (+/- 0.18) -0.24 1.46 3 2 PLBartForCausalLM 1.93 (+/- 0.01) 1.94 (+/- 0.01) -0.00 15.30 (+/- 0.22) 16.01 (+/- 0.71) -0.71 4.43 3 2 PLBartForConditionalGeneration 1.98 (+/- 0.01) 1.98 (+/- 0.01) 0.00 25.90 (+/- 0.32) 26.58 (+/- 0.62) -0.67 2.53 3 1 TrOCRForCausalLM 1.61 (+/- 0.00) 1.62 (+/- 0.00) -0.01 21.38 (+/- 0.37) 21.85 (+/- 0.16) -0.47 2.16 3 2 Mean difference in speedup: 0.00 Mean compilation latency saved: -0.38s Mean compilation latency reduction: 2.07% ``` For now, the heuristic can only be applied to decide whether to pad for mm. One could also learn heuristics for bmm and addmm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128643 Approved by: https://github.com/Chillee, https://github.com/eellison
PyTorch is a Python package that provides two high-level features:
You can reuse your favorite Python packages such as NumPy, SciPy, and Cython to extend PyTorch when needed.
Our trunk health (Continuous Integration signals) can be found at hud.pytorch.org.
At a granular level, PyTorch is a library that consists of the following components:
Component | Description |
---|---|
torch | A Tensor library like NumPy, with strong GPU support |
torch.autograd | A tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
torch.jit | A compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code |
torch.nn | A neural networks library deeply integrated with autograd designed for maximum flexibility |
torch.multiprocessing | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
torch.utils | DataLoader and other utility functions for convenience |
Usually, PyTorch is used either as:
Elaborating Further:
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the computation by a huge amount.
We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs such as slicing, indexing, mathematical operations, linear algebra, reductions. And they are fast!
PyTorch has a unique way of building neural networks: using and replaying a tape recorder.
Most frameworks such as TensorFlow, Theano, Caffe, and CNTK have a static view of the world. One has to build a neural network and reuse the same structure again and again. Changing the way the network behaves means that one has to start from scratch.
With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to change the way your network behaves arbitrarily with zero lag or overhead. Our inspiration comes from several research papers on this topic, as well as current and past work such as torch-autograd, autograd, Chainer, etc.
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research.
PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use NumPy / SciPy / scikit-learn etc. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as Cython and Numba. Our goal is to not reinvent the wheel where appropriate.
PyTorch is designed to be intuitive, linear in thought, and easy to use. When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. When you drop into a debugger or receive error messages and stack traces, understanding them is straightforward. The stack trace points to exactly where your code was defined. We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines.
PyTorch has minimal framework overhead. We integrate acceleration libraries such as Intel MKL and NVIDIA (cuDNN, NCCL) to maximize speed. At the core, its CPU and GPU Tensor and neural network backends are mature and have been tested for years.
Hence, PyTorch is quite fast — whether you run small or large neural networks.
The memory usage in PyTorch is extremely efficient compared to Torch or some of the alternatives. We've written custom memory allocators for the GPU to make sure that your deep learning models are maximally memory efficient. This enables you to train bigger deep learning models than before.
Writing new neural network modules, or interfacing with PyTorch's Tensor API was designed to be straightforward and with minimal abstractions.
You can write new neural network layers in Python using the torch API or your favorite NumPy-based libraries such as SciPy.
If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and with minimal boilerplate. No wrapper code needs to be written. You can see a tutorial here and an example here.
Commands to install binaries via Conda or pip wheels are on our website: https://pytorch.org/get-started/locally/
Python wheels for NVIDIA's Jetson Nano, Jetson TX1/TX2, Jetson Xavier NX/AGX, and Jetson AGX Orin are provided here and the L4T container is published here
They require JetPack 4.2 and above, and @dusty-nv and @ptrblck are maintaining them.
If you are installing from source, you will need:
We highly recommend installing an Anaconda environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro.
If you want to compile with CUDA support, select a supported version of CUDA from our support matrix, then install the following:
Note: You could refer to the cuDNN Support Matrix for cuDNN versions with the various supported CUDA, CUDA driver and NVIDIA hardware
If you want to disable CUDA support, export the environment variable USE_CUDA=0
. Other potentially useful environment variables may be found in setup.py
.
If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xavier), Instructions to install PyTorch for Jetson Nano are available here
If you want to compile with ROCm support, install
If you want to disable ROCm support, export the environment variable USE_ROCM=0
. Other potentially useful environment variables may be found in setup.py
.
If you want to compile with Intel GPU support, follow these
If you want to disable Intel GPU support, export the environment variable USE_XPU=0
. Other potentially useful environment variables may be found in setup.py
.
Common
conda install cmake ninja # Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below pip install -r requirements.txt
On Linux
conda install intel::mkl-static intel::mkl-include # CUDA only: Add LAPACK support for the GPU if needed conda install -c pytorch magma-cuda121 # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo # (optional) If using torch.compile with inductor/triton, install the matching version of triton # Run from the pytorch directory after cloning # For Intel GPU support, please explicitly `export USE_XPU=1` before running command. make triton
On MacOS
# Add this package on intel x86 processor machines only conda install intel::mkl-static intel::mkl-include # Add these packages if torch.distributed is needed conda install pkg-config libuv
On Windows
conda install intel::mkl-static intel::mkl-include # Add these packages if torch.distributed is needed. # Distributed package support on Windows is a prototype feature and is subject to changes. conda install -c conda-forge libuv=1.39
git clone --recursive https://github.com/pytorch/pytorch cd pytorch # if you are updating an existing checkout git submodule sync git submodule update --init --recursive
On Linux
If you would like to compile PyTorch with new C++ ABI enabled, then first run this command:
export _GLIBCXX_USE_CXX11_ABI=1
If you're compiling for AMD ROCm then first run this command:
# Only run this if you're compiling for ROCm python tools/amd_build/build_amd.py
Install PyTorch
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} python setup.py develop
Aside: If you are using Anaconda, you may experience an error caused by the linker:
build/temp.linux-x86_64-3.7/torch/csrc/stub.o: file not recognized: file format not recognized collect2: error: ld returned 1 exit status error: command 'g++' failed with exit status 1This is caused by
ld
from the Conda environment shadowing the systemld
. You should use a newer version of Python that fixes this issue. The recommended Python version is 3.8.1+.
On macOS
python3 setup.py develop
On Windows
Choose Correct Visual Studio Version.
PyTorch CI uses Visual C++ BuildTools, which come with Visual Studio Enterprise, Professional, or Community Editions. You can also install the build tools from https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools do not come with Visual Studio Code by default.
If you want to build legacy python code, please refer to Building on legacy code and CUDA
CPU-only builds
In this mode PyTorch computations will run on your CPU, not your GPU
conda activate python setup.py develop
Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking CMAKE_INCLUDE_PATH
and LIB
. The instruction here is an example for setting up both MKL and Intel OpenMP. Without these configurations for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used.
CUDA based build
In this mode PyTorch computations will leverage your GPU via CUDA for faster number crunching
NVTX is needed to build Pytorch with CUDA. NVTX is a part of CUDA distributive, where it is called “Nsight Compute”. To install it onto an already installed CUDA run CUDA installation once again and check the corresponding checkbox. Make sure that CUDA with Nsight Compute is installed after Visual Studio.
Currently, VS 2017 / 2019, and Ninja are supported as the generator of CMake. If ninja.exe
is detected in PATH
, then Ninja will be used as the default generator, otherwise, it will use VS 2017 / 2019.
If Ninja is selected as the generator, the latest MSVC will get selected as the underlying toolchain.
Additional libraries such as Magma, oneDNN, a.k.a. MKLDNN or DNNL, and Sccache are often needed. Please refer to the installation-helper to install them.
You can refer to the build_pytorch.bat script for some other environment variables configurations
cmd :: Set the environment variables after you have downloaded and unzipped the mkl package, :: else CMake would throw an error as `Could NOT find OpenMP`. set CMAKE_INCLUDE_PATH={Your directory}\mkl\include set LIB={Your directory}\mkl\lib;%LIB% :: Read the content in the previous section carefully before you proceed. :: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block. :: "Visual Studio 2019 Developer Command Prompt" will be run automatically. :: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator. set CMAKE_GENERATOR_TOOLSET_VERSION=14.27 set DISTUTILS_USE_SDK=1 for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,17^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION% :: [Optional] If you want to override the CUDA host compiler set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe python setup.py develop
You can adjust the configuration of cmake variables optionally (without building first), by doing the following. For example, adjusting the pre-detected directories for CuDNN or BLAS can be done with such a step.
On Linux
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} python setup.py build --cmake-only ccmake build # or cmake-gui build
On macOS
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build --cmake-only ccmake build # or cmake-gui build
You can also pull a pre-built docker image from Docker Hub and run with docker v19.03+
docker run --gpus all --rm -ti --ipc=host pytorch/pytorch:latest
Please note that PyTorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g. for multithreaded data loaders) the default shared memory segment size that container runs with is not enough, and you should increase shared memory size either with --ipc=host
or --shm-size
command line options to nvidia-docker run
.
NOTE: Must be built with a docker version > 18.06
The Dockerfile
is supplied to build images with CUDA 11.1 support and cuDNN v8. You can pass PYTHON_VERSION=x.y
make variable to specify which Python version is to be used by Miniconda, or leave it unset to use the default.
make -f docker.Makefile # images are tagged as docker.io/${your_docker_username}/pytorch
You can also pass the CMAKE_VARS="..."
environment variable to specify additional CMake variables to be passed to CMake during the build. See setup.py for the list of available variables.
make -f docker.Makefile
To build documentation in various formats, you will need Sphinx and the readthedocs theme.
cd docs/ pip install -r requirements.txt
You can then build the documentation by running make <format>
from the docs/
folder. Run make
to get a list of all available output formats.
If you get a katex error run npm install katex
. If it persists, try npm install -g katex
Note: if you installed
nodejs
with a different package manager (e.g.,conda
) thennpm
will probably install a version ofkatex
that is not compatible with your version ofnodejs
and doc builds will fail. A combination of versions that is known to work isnode@6.13.1
andkatex@0.13.18
. To install the latter withnpm
you can runnpm install -g katex@0.13.18
Installation instructions and binaries for previous PyTorch versions may be found on our website.
Three-pointers to get you started:
Typically, PyTorch has three minor releases a year. Please let us know if you encounter a bug by filing an issue.
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.
If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the core in a different direction than you might be aware of.
To learn more about making a contribution to Pytorch, please see our Contribution page. For more information about PyTorch releases, see Release page.
PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
PyTorch is currently maintained by Soumith Chintala, Gregory Chanan, Dmytro Dzhulgakov, Edward Yang, and Nikita Shulga with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito.
Note: This project is unrelated to hughperkins/pytorch with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch.
PyTorch has a BSD-style license, as found in the LICENSE file.