commit | 67414714e51b96f4820a842aa6727f954a29b66e | [log] [tgz] |
---|---|---|
author | Syed Tousif Ahmed <syed.ahmed.emails@gmail.com> | Mon May 13 09:35:30 2019 -0700 |
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | Mon May 13 09:38:28 2019 -0700 |
tree | 11a7dd47fb2297b02f68de8fcb999b6355926b7d | |
parent | 5f7ef09f5777d35eb1f08e6c675eb2ead41d6165 [diff] |
Move THCTensor_(uniform) to ATen (#20292) Summary: As a first step for this plan: https://github.com/pytorch/pytorch/issues/19508#issuecomment-485178192, this PR moves `THCTensor_(uniform)` to ATen. Major changes are: - `uniform_` cuda kernel now utilizes a philox generator. - the kernel also utilizes TensorIterator - the kernel uses a grid-stride loop to achieve peak effective bandwidth - Since the engine has changed from `curandStateMTGP32` to `curandStatePhilox4_32_10`, the randoms generated now will be different. - Here is the diff showing codegen changes: https://gist.github.com/syed-ahmed/4af9ae0d42b6c7dbaa13b9dd0d1dd1e8 (BC breaking change if any) - Philox4_32_10 is known to pass the standard TestU01 Big Crush test (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) and hence the quality of random numbers generated isn't an issue when compared to the previously used `curandStateMTGP32`. - I have added a test case in `aten/src/ATen/test/cuda_distributions_test.cu` which verifies that philox offset is incremented properly The benchmark was done on a DGX station with 4 V100s. I modified the script from jcjohnson 's [multinomial benchmark](https://github.com/jcjohnson/pytorch-multinomial-benchmark) to produce this notebook which shows that there is a general speedup with this PR and a regression hasn't been introduced: https://gist.github.com/syed-ahmed/9d26d4e96308aed274d0f2c7be5218ef To reproduce the notebook: - Run https://gist.github.com/syed-ahmed/4208c22c541f1d30ad6a9b1efc1d728f in a container with the current pytorch top of tree with the command: `python uniform_benchmark.py --stats_json before.json` - Apply this diff to the current pytorch top of tree and run the same script in a container with the command: `python uniform_benchmark.py --stats_json after.json` - Run the notebook attached above with the `after.json` and `before.json` in the same directory The effected bandwidth was calculated using the script (thanks to ngimel ): https://gist.github.com/syed-ahmed/f8b7384d642f4bce484228b508b4bc68 Following are the numbers before and after. ``` uniform, size, elements 65536 forward 5.168914794921875e-06 bandwidth (GB/s) 50.71548098597786 uniform, size, elements 131072 forward 5.056858062744141e-06 bandwidth (GB/s) 103.67860705101367 uniform, size, elements 262144 forward 7.164478302001953e-06 bandwidth (GB/s) 146.357621001797 uniform, size, elements 524288 forward 1.1217594146728515e-05 bandwidth (GB/s) 186.9520302275877 uniform, size, elements 1048576 forward 1.923084259033203e-05 bandwidth (GB/s) 218.10297600317384 uniform, size, elements 2097152 forward 3.640890121459961e-05 bandwidth (GB/s) 230.39992200138826 uniform, size, elements 4194304 forward 6.778717041015625e-05 bandwidth (GB/s) 247.49839679819922 uniform, size, elements 8388608 forward 0.00012810707092285157 bandwidth (GB/s) 261.92490202361347 uniform, size, elements 16777216 forward 0.00025241613388061524 bandwidth (GB/s) 265.86598474620627 uniform, size, elements 33554432 forward 0.000497891902923584 bandwidth (GB/s) 269.5720239913193 ``` ``` uniform, size, elements 65536 forward 5.550384521484375e-06 bandwidth (GB/s) 47.22988091821306 uniform, size, elements 131072 forward 5.581378936767578e-06 bandwidth (GB/s) 93.93520954942333 uniform, size, elements 262144 forward 6.165504455566406e-06 bandwidth (GB/s) 170.071404141686 uniform, size, elements 524288 forward 6.3276290893554685e-06 bandwidth (GB/s) 331.4277702414469 uniform, size, elements 1048576 forward 8.509159088134765e-06 bandwidth (GB/s) 492.91639239047356 uniform, size, elements 2097152 forward 1.2989044189453124e-05 bandwidth (GB/s) 645.8218077979443 uniform, size, elements 4194304 forward 2.347707748413086e-05 bandwidth (GB/s) 714.6211452997259 uniform, size, elements 8388608 forward 4.4286251068115234e-05 bandwidth (GB/s) 757.6715389250498 uniform, size, elements 16777216 forward 8.672237396240235e-05 bandwidth (GB/s) 773.8356427961071 uniform, size, elements 33554432 forward 0.00016920566558837892 bandwidth (GB/s) 793.2224227438523 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/20292 Differential Revision: D15277761 Pulled By: ezyang fbshipit-source-id: 8bfe31a01eeed77f0ed6e7ec4d2dda4c6472ecaa
PyTorch is a Python package that provides two high-level features:
You can reuse your favorite Python packages such as NumPy, SciPy and Cython to extend PyTorch when needed.
System | 2.7 | 3.5 | 3.6 |
---|---|---|---|
Linux CPU | — | ||
Linux GPU | — | ||
Windows CPU / GPU | — | — | |
Linux (ppc64le) CPU | — | ||
Linux (ppc64le) GPU | — |
See also the ci.pytorch.org HUD.
At a granular level, PyTorch is a library that consists of the following components:
Component | Description |
---|---|
torch | a Tensor library like NumPy, with strong GPU support |
torch.autograd | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
torch.jit | a compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code |
torch.nn | a neural networks library deeply integrated with autograd designed for maximum flexibility |
torch.multiprocessing | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
torch.utils | DataLoader and other utility functions for convenience |
Usually one uses PyTorch either as:
Elaborating further:
If you use NumPy, then you have used Tensors (a.k.a ndarray).
PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerates the computation by a huge amount.
We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs such as slicing, indexing, math operations, linear algebra, reductions. And they are fast!
PyTorch has a unique way of building neural networks: using and replaying a tape recorder.
Most frameworks such as TensorFlow, Theano, Caffe and CNTK have a static view of the world. One has to build a neural network, and reuse the same structure again and again. Changing the way the network behaves means that one has to start from scratch.
With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to change the way your network behaves arbitrarily with zero lag or overhead. Our inspiration comes from several research papers on this topic, as well as current and past work such as torch-autograd, autograd, Chainer, etc.
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research.
PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use NumPy / SciPy / scikit-learn etc. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as Cython and Numba. Our goal is to not reinvent the wheel where appropriate.
PyTorch is designed to be intuitive, linear in thought and easy to use. When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. When you drop into a debugger, or receive error messages and stack traces, understanding them is straightforward. The stack trace points to exactly where your code was defined. We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines.
PyTorch has minimal framework overhead. We integrate acceleration libraries such as Intel MKL and NVIDIA (cuDNN, NCCL) to maximize speed. At the core, its CPU and GPU Tensor and neural network backends (TH, THC, THNN, THCUNN) are mature and have been tested for years.
Hence, PyTorch is quite fast – whether you run small or large neural networks.
The memory usage in PyTorch is extremely efficient compared to Torch or some of the alternatives. We've written custom memory allocators for the GPU to make sure that your deep learning models are maximally memory efficient. This enables you to train bigger deep learning models than before.
Writing new neural network modules, or interfacing with PyTorch's Tensor API was designed to be straightforward and with minimal abstractions.
You can write new neural network layers in Python using the torch API or your favorite NumPy-based libraries such as SciPy.
If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and with minimal boilerplate. There is no wrapper code that needs to be written. You can see a tutorial here and an example here.
Commands to install from binaries via Conda or pip wheels are on our website: https://pytorch.org
Python wheels for NVIDIA's Jetson Nano, Jetson TX2, and Jetson AGX Xavier are available via the following URLs:
They requires JetPack 4.2 and above and are maintained by @dusty-nv
If you are installing from source, we highly recommend installing an Anaconda environment. You will get a high-quality BLAS library (MKL) and you get a controlled compiler version regardless of your Linux distro.
Once you have Anaconda installed, here are the instructions.
If you want to compile with CUDA support, install
If you want to disable CUDA support, export environment variable NO_CUDA=1
. Other potentially useful environment variables may be found in setup.py
.
If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xavier), Instructions to are available here
Common
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing
On Linux
# Add LAPACK support for the GPU if needed conda install -c pytorch magma-cuda90 # or [magma-cuda80 | magma-cuda92 | magma-cuda100 ] depending on your cuda version
git clone --recursive https://github.com/pytorch/pytorch cd pytorch # if you are updating an existing checkout git submodule sync git submodule update --init --recursive
On Linux
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} python setup.py install
On macOS
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
On Windows
At least Visual Studio 2017 Update 3 (version 15.3.3 with the toolset 14.11) and NVTX are needed.
If the version of Visual Studio 2017 is higher than 15.4.5, installing of “VC++ 2017 version 15.4 v14.11 toolset” is strongly recommended.
If the version of Visual Studio 2017 is lesser than 15.3.3, please update Visual Studio 2017 to the latest version along with installing “VC++ 2017 version 15.4 v14.11 toolset”.
There is no guarantee of the correct building with VC++ 2017 toolsets, others than version 15.4 v14.11.
“VC++ 2017 version 15.4 v14.11 toolset” might be installed onto already installed Visual Studio 2017 by running its installation once again and checking the corresponding checkbox under “Individual components”/“Compilers, build tools, and runtimes”.
For building against CUDA 8.0 Visual Studio 2015 Update 3 (version 14.0), and the patch are needed to be installed too. The details of the patch can be found here.
NVTX is a part of CUDA distributive, where it is called “Nsight Compute”. For installing it onto already installed CUDA run CUDA installation once again and check the corresponding checkbox. Be sure that CUDA with Nsight Compute is installed after Visual Studio 2017.
cmd REM [Optional] The following two lines are needed for Python 2.7, but the support for it is very experimental. set MSSdk=1 set FORCE_PY27_BUILD=1 REM [Optional] As for CUDA 8, VS2015 Update 3 is required; use the following line. set "CUDAHOSTCXX=%VS140COMNTOOLS%..\..\VC\bin\amd64\cl.exe" set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 set DISTUTILS_USE_SDK=1 REM Run "Visual Studio 2017 Developer Command Prompt" for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,16^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=14.11 python setup.py install
Dockerfile is supplied to build images with cuda support and cudnn v7. You can pass -e PYTHON_VERSION=x.y
flag to specify which Python version is to be used by Miniconda, or leave it unset to use the default. Build from pytorch repo directory as docker needs to copy git repo into docker filesystem while building the image.
docker build -t pytorch -f docker/pytorch/Dockerfile .
You can also pull a pre-built docker image from Docker Hub and run with nvidia-docker, but this is not currently maintained and will pull PyTorch 0.2.
nvidia-docker run --rm -ti --ipc=host pytorch/pytorch:latest
Please note that PyTorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g. for multithreaded data loaders) the default shared memory segment size that container runs with is not enough, and you should increase shared memory size either with --ipc=host
or --shm-size
command line options to nvidia-docker run
.
To build documentation in various formats, you will need Sphinx and the readthedocs theme.
cd docs/ pip install -r requirements.txt
You can then build the documentation by running make <format>
from the docs/
folder. Run make
to get a list of all available output formats.
Installation instructions and binaries for previous PyTorch versions may be found on our website.
Three pointers to get you started:
PyTorch has a 90 day release cycle (major releases). Please let us know if you encounter a bug by filing an issue.
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.
If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.
PyTorch is a community driven project with several skillful engineers and researchers contributing to it.
PyTorch is currently maintained by Adam Paszke, Sam Gross, Soumith Chintala and Gregory Chanan with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Kopf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito.
Note: this project is unrelated to hughperkins/pytorch with the same name. Hugh is a valuable contributor in the Torch community and has helped with many things Torch and PyTorch.
PyTorch is BSD-style licensed, as found in the LICENSE file.