PyTorch is a Python package that provides two high-level features:
- Tensor computation (like NumPy) with strong GPU acceleration
- Deep neural networks built on a tape-based autograd system
You can reuse your favorite Python packages such as NumPy, SciPy, and Cython to extend PyTorch when needed.
(CI build-status badge table, not reproduced here: Linux CPU, Linux GPU, Windows CPU / GPU, Linux (ppc64le) CPU, Linux (ppc64le) GPU, and Linux (aarch64) CPU, across Python 3.6, 3.7, and 3.8.)
See also the ci.pytorch.org HUD.
At a granular level, PyTorch is a library that consists of the following components:
Component | Description |
---|---|
torch | a Tensor library like NumPy, with strong GPU support |
torch.autograd | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
torch.jit | a compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code |
torch.nn | a neural networks library deeply integrated with autograd designed for maximum flexibility |
torch.multiprocessing | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
torch.utils | DataLoader and other utility functions for convenience |
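To give a sense of how the components listed above fit together, here is a minimal training-loop sketch; the tensor shapes, layer sizes, and learning rate are arbitrary illustrative choices, not prescriptions:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# torch: tensors, with optional GPU placement
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(64, 10, device=device)   # 64 samples, 10 features (arbitrary)
y = torch.randn(64, 1, device=device)

# torch.utils.data: wrap tensors in a Dataset and iterate in mini-batches
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

# torch.nn: a small model; torch.autograd computes the gradients in backward()
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for xb, yb in loader:
    loss = nn.functional.mse_loss(model(xb), yb)
    optimizer.zero_grad()
    loss.backward()   # reverse-mode autodiff through the recorded operations
    optimizer.step()
```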
Usually, PyTorch is used either as:
- A replacement for NumPy to use the power of GPUs.
- A deep learning research platform that provides maximum flexibility and speed.
Elaborating Further:
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the computation by a huge amount.
We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs such as slicing, indexing, math operations, linear algebra, reductions. And they are fast!
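A small illustrative snippet (the sizes are arbitrary):

```python
import torch

a = torch.randn(1000, 1000)    # CPU tensor
b = a[:, :500]                 # slicing and indexing
c = b.t() @ b                  # linear algebra (matrix multiply)
m = c.mean(dim=0)              # reduction

if torch.cuda.is_available():  # the same routines run on the GPU
    a_gpu = a.cuda()
    c_gpu = a_gpu.t() @ a_gpu
```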
PyTorch has a unique way of building neural networks: using and replaying a tape recorder.
Most frameworks such as TensorFlow, Theano, Caffe, and CNTK have a static view of the world. One has to build a neural network and reuse the same structure again and again. Changing the way the network behaves means that one has to start from scratch.
With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to change the way your network behaves arbitrarily with zero lag or overhead. Our inspiration comes from several research papers on this topic, as well as current and past work such as torch-autograd, autograd, Chainer, etc.
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research.
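A brief illustration of this define-by-run behavior; the computation itself is arbitrary:

```python
import torch

x = torch.randn(3, requires_grad=True)

# The graph is rebuilt on every forward pass, so ordinary Python control flow
# (loops, branches, even randomness) can change the network's structure each run.
y = x
for _ in range(torch.randint(1, 5, ()).item()):
    y = torch.tanh(y * 2)

y.sum().backward()   # reverse-mode autodiff through whatever was just executed
print(x.grad)
```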
PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use NumPy / SciPy / scikit-learn etc. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as Cython and Numba. Our goal is to not reinvent the wheel where appropriate.
PyTorch is designed to be intuitive, linear in thought, and easy to use. When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. When you drop into a debugger or receive error messages and stack traces, understanding them is straightforward. The stack trace points to exactly where your code was defined. We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines.
PyTorch has minimal framework overhead. We integrate acceleration libraries such as Intel MKL and NVIDIA (cuDNN, NCCL) to maximize speed. At the core, its CPU and GPU Tensor and neural network backends (TH, THC, THNN, THCUNN) are mature and have been tested for years.
Hence, PyTorch is quite fast – whether you run small or large neural networks.
The memory usage in PyTorch is extremely efficient compared to Torch or some of the alternatives. We've written custom memory allocators for the GPU to make sure that your deep learning models are maximally memory efficient. This enables you to train bigger deep learning models than before.
Writing new neural network modules, or interfacing with PyTorch's Tensor API, was designed to be straightforward, with minimal abstractions.
You can write new neural network layers in Python using the torch API or your favorite NumPy-based libraries such as SciPy.
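For instance, a new layer is simply an nn.Module subclass; ScaledResidual below is a made-up example rather than an existing PyTorch layer:

```python
import torch
from torch import nn

class ScaledResidual(nn.Module):
    """A hypothetical custom layer written purely in Python with the torch API."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x + self.scale * torch.relu(self.linear(x))

layer = ScaledResidual(8)
out = layer(torch.randn(4, 8))
out.sum().backward()   # autograd differentiates through the custom forward()
```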
If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and requires minimal boilerplate. No wrapper code needs to be written. You can see a tutorial here and an example here.
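As a rough sketch of the JIT-compiled inline path (this assumes a working C++ toolchain on the machine; scaled_add is an invented example function):

```python
import torch
from torch.utils import cpp_extension

cpp_source = """
#include <torch/extension.h>

torch::Tensor scaled_add(torch::Tensor a, torch::Tensor b, double alpha) {
  return a + alpha * b;   // ATen ops dispatch to CPU or CUDA as appropriate
}
"""

# Compiles on first use; bindings for the listed functions are generated automatically.
ext = cpp_extension.load_inline(
    name="scaled_add_ext",
    cpp_sources=cpp_source,
    functions=["scaled_add"],
)

print(ext.scaled_add(torch.ones(3), torch.ones(3), 0.5))
```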
Commands to install from binaries via Conda or pip wheels are on our website: https://pytorch.org
Python wheels for NVIDIA's Jetson Nano, Jetson TX2, and Jetson AGX Xavier are available via the following URLs:
They require JetPack 4.2 and above, and are maintained by @dusty-nv.
If you are installing from source, you will need Python 3.6.2 or later and a C++14 compiler. Also, we highly recommend installing an Anaconda environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro.
Once you have Anaconda installed, here are the instructions.
If you want to compile with CUDA support, install an NVIDIA CUDA toolkit and cuDNN, along with a compiler compatible with your CUDA version. If you want to disable CUDA support, export the environment variable USE_CUDA=0. Other potentially useful environment variables may be found in setup.py.
If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xavier), instructions to install PyTorch for Jetson Nano are available here.
If you want to compile with ROCm support, install AMD ROCm (ROCm is currently supported only for Linux systems). If you want to disable ROCm support, export the environment variable USE_ROCM=0. Other potentially useful environment variables may be found in setup.py.
Common
conda install astunparse numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
On Linux
# CUDA only: Add LAPACK support for the GPU if needed
conda install -c pytorch magma-cuda110  # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo
On macOS
# Add these packages if torch.distributed is needed
conda install pkg-config libuv
On Windows
# Add these packages if torch.distributed is needed.
# Distributed package support on Windows is a prototype feature and is subject to changes.
conda install -c conda-forge libuv=1.39
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
# if you are updating an existing checkout
git submodule sync
git submodule update --init --recursive --jobs 0
On Linux
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
python setup.py install
Note that if you are compiling for ROCm, you must run this command first:
python tools/amd_build/build_amd.py
Note that if you are using Anaconda, you may experience an error caused by the linker:
build/temp.linux-x86_64-3.7/torch/csrc/stub.o: file not recognized: file format not recognized
collect2: error: ld returned 1 exit status
error: command 'g++' failed with exit status 1
This is caused by the ld from the Conda environment shadowing the system ld. You should use a newer version of Python that fixes this issue. The recommended Python versions are 3.6.10+, 3.7.6+, and 3.8.1+.
On macOS
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
Each CUDA version only supports one particular Xcode version. The following combinations have been reported to work with PyTorch.
CUDA version | Xcode version |
---|---|
10.0 | Xcode 9.4 |
10.1 | Xcode 10.1 |
On Windows
Choose Correct Visual Studio Version.
Sometimes there are regressions in new versions of Visual Studio, so it's best to use the same Visual Studio version 16.8.5 as PyTorch CI does. You can use Visual Studio Enterprise, Professional, or Community, though PyTorch CI uses Visual Studio BuildTools.
If you want to build legacy Python code, please refer to Building on legacy code and CUDA.
Build with CPU
It's fairly easy to build with CPU.
Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking CMAKE_INCLUDE_PATH and LIB. The instructions here are an example of setting up both MKL and Intel OpenMP. Without these configurations for CMake, the Microsoft Visual C OpenMP runtime (vcomp) will be used.
Build with CUDA
NVTX is needed to build PyTorch with CUDA. NVTX is part of the CUDA distribution, where it is called "Nsight Compute". To install it onto an already installed CUDA, run the CUDA installer again and check the corresponding checkbox. Make sure that CUDA with Nsight Compute is installed after Visual Studio.
Currently, VS 2017 / 2019 and Ninja are supported as the generators of CMake. If ninja.exe is detected in PATH, then Ninja will be used as the default generator; otherwise, it will use VS 2017 / 2019.
If Ninja is selected as the generator, the latest MSVC will get selected as the underlying toolchain.
Additional libraries such as Magma, oneDNN (a.k.a. MKL-DNN or DNNL), and Sccache are often needed. Please refer to the installation-helper to install them.
You can refer to the build_pytorch.bat script for other environment variable configurations.
:: [Optional] If you want to build with the VS 2017 generator for old CUDA and PyTorch, please change the value in the next line to `Visual Studio 15 2017`.
:: Note: This value is useless if Ninja is detected. However, you can force that by using `set USE_NINJA=OFF`.
set CMAKE_GENERATOR=Visual Studio 16 2019

:: Read the content in the previous section carefully before you proceed.
:: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block.
:: "Visual Studio 2019 Developer Command Prompt" will be run automatically.
:: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator.
set CMAKE_GENERATOR_TOOLSET_VERSION=14.27
set DISTUTILS_USE_SDK=1
for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,16^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION%

:: [Optional] If you want to override the CUDA host compiler
set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe

python setup.py install
You can optionally adjust the configuration of CMake variables (without building first) by doing the following. For example, adjusting the pre-detected directories for cuDNN or BLAS can be done with such a step.
On Linux
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
python setup.py build --cmake-only
ccmake build  # or cmake-gui build
On macOS
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build --cmake-only
ccmake build  # or cmake-gui build
You can also pull a pre-built Docker image from Docker Hub and run it with Docker v19.03+:
docker run --gpus all --rm -ti --ipc=host pytorch/pytorch:latest
Please note that PyTorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with is not enough, and you should increase the shared memory size either with the --ipc=host or --shm-size command line options to nvidia-docker run.
NOTE: Must be built with a Docker version > 18.06.
The Dockerfile is supplied to build images with CUDA 11.1 support and cuDNN v8. You can pass the PYTHON_VERSION=x.y make variable to specify which Python version is to be used by Miniconda, or leave it unset to use the default.
make -f docker.Makefile # images are tagged as docker.io/${your_docker_username}/pytorch
To build documentation in various formats, you will need Sphinx and the readthedocs theme.
cd docs/
pip install -r requirements.txt
You can then build the documentation by running make <format> from the docs/ folder. Run make to get a list of all available output formats.
If you get a katex error, run npm install katex. If it persists, try npm install -g katex.
Installation instructions and binaries for previous PyTorch versions may be found on our website.
Three pointers to get you started:
- Tutorials: get you started with understanding and using PyTorch
- Examples: easy to understand PyTorch code across all domains
- The API Reference
- Glossary
PyTorch has a 90-day release cycle (major releases). Please let us know if you encounter a bug by filing an issue.
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.
If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the core in a different direction than you might be aware of.
To learn more about making a contribution to PyTorch, please see our Contribution page.
PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
PyTorch is currently maintained by Adam Paszke, Sam Gross, Soumith Chintala and Gregory Chanan with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito.
Note: This project is unrelated to hughperkins/pytorch with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch.
PyTorch has a BSD-style license, as found in the LICENSE file.