| commit | a2afad2b696e77ef71855aa0284abaf7eecb3c67 |
|---|---|
| author | mruberry <mruberry@nvidia.com>, Thu Sep 06 21:37:12 2018 -0700 |
| committer | Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>, Thu Sep 06 21:39:44 2018 -0700 |
| tree | d8ec6095faa4a359e4c9dc4c5ddffed628537c13 |
| parent | b3b1e7624dc7c48309bfd38f527b3bdabf920631 |
Improves ATen CUDAEvent (#11293)

Summary: After submitting PR #9726, PR #10581 created a different CUDAEvent class. The CUDAEvent proposed in #9726 was similar to the c10d::CUDAEvent class with additional testing and functionality. In particular, it was movable but not copyable. The CUDAEvent created by #10581 is refcounted and copyable. This PR retains the refcounting of the latter PR while fixing several bugs, adding tests, and extending the functionality to support testing and usage like in PR #8354.

In particular, this PR:

- Adds set_device() to CUDAContext
- Adds three CUDAEvent tests to stream_test.cpp
- Fixes three bugs:
  - Refcounting was broken. Destroying any of the RAIIs holding a particular CUDAEvent would destroy the event UNLESS it was the last RAII (the check was backwards).
  - Moving an event would cause a segfault.
  - Events were not destroyed on the device they were created on. See PR #9415 (pietern)
- Adds the happened() and recordOnce() functions
- Changes the record() functions to not be const
- Adds additional assertions to verify correctness

This PR does not:

- Make c10d use the ATen CUDAEvent (this is appropriate for a separate PR)

Whether events should be refcounted is an interesting question. It adds some atomic operations and makes event creation eager. Making events movable but not copyable (like the c10d events) avoids these costs and allows events to be lazily constructed. Lazy construction is preferable when working with containers (like std::array or std::vector) and because the event's device can be set automatically to the first stream it's recorded on. With eager construction the user is required to understand that events have a device and acquire the device of the stream the event will be recorded on upfront. This can be seen here: https://github.com/pytorch/pytorch/blob/542aadd9a7609892e207c1e15de08a975b697752/aten/src/ATen/native/cudnn/RNN.cpp#L1130-L1132 and that file is the only one which currently uses the ATen CUDAEvent.

Refcounting does allow single-writer multi-reader scenarios, although these scenarios can also be supported by providing indirect access to the underlying CUDAEvent. I believe all current and planned usage scenarios do not require refcounting, and if desired I can update this PR to remove refcounting and make the ATen event movable but not copyable like the c10d event. I think not refcounting is preferable because it can improve performance, ease usability, and simplify the code (as seen with two of the above bugs).

I have decided to separate this from PR #8354 since, while it's required for PR #8354, the changes are, clearly, of independent interest. PR #8354 has a new dependency on this one, however. I am closing PR #9726 in favor of this PR.

apaszke ezyang pietern

Pull Request resolved: https://github.com/pytorch/pytorch/pull/11293
Differential Revision: D9665836
Pulled By: soumith
fbshipit-source-id: a1513fa4f9761e2f304d126e402f6b6950e1c1d2
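As a rough illustration of the record/query semantics discussed in the commit message, here is a hedged sketch using the Python-level `torch.cuda.Event`, whose `query()` plays the role of the `happened()` function mentioned above. This is illustrative usage, not code from the PR, and it assumes a CUDA-capable GPU:

```python
# Sketch of CUDA event usage from Python; torch.cuda.Event is the
# Python-level counterpart of the ATen event discussed above.
import torch

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        start.record()                     # record on the current stream
        x = torch.randn(1024, 1024, device="cuda")
        y = x @ x                          # some GPU work
        end.record()

    end.synchronize()                      # block until the event completes
    print(end.query())                     # True: all preceding work is done
    print(start.elapsed_time(end), "ms")   # elapsed GPU time in milliseconds
```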
PyTorch is a Python package that provides two high-level features:

- Tensor computation (like NumPy) with strong GPU acceleration
- Deep neural networks built on a tape-based autograd system
You can reuse your favorite Python packages such as NumPy, SciPy and Cython to extend PyTorch when needed.
We are in an early-release beta. Expect some adventures and rough edges.
System | 2.7 | 3.5 |
---|---|---|
Linux CPU | build status badge | build status badge |
Linux GPU | build status badge | build status badge |
Windows GPU | — | build status badge |
See also the ci.pytorch.org HUD.
At a granular level, PyTorch is a library that consists of the following components:
Component | Description |
---|---|
torch | a Tensor library like NumPy, with strong GPU support |
torch.autograd | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
torch.nn | a neural networks library deeply integrated with autograd designed for maximum flexibility |
torch.multiprocessing | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
torch.utils | DataLoader, Trainer and other utility functions for convenience |
torch.legacy(.nn/.optim) | legacy code that has been ported over from torch for backward compatibility reasons |
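As a minimal sketch of the `torch.utils` DataLoader mentioned in the table above (the dataset and shapes here are made up for illustration):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

if __name__ == "__main__":  # guard needed when num_workers > 0 on some platforms
    inputs = torch.randn(100, 3)           # 100 samples, 3 features each
    targets = torch.randint(0, 2, (100,))  # binary labels

    dataset = TensorDataset(inputs, targets)
    # num_workers > 0 loads batches in parallel via torch.multiprocessing
    loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

    for batch_inputs, batch_targets in loader:
        pass  # each batch holds up to 16 samples and their labels
```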
Usually one uses PyTorch either as:

- a replacement for NumPy to use the power of GPUs
- a deep learning research platform that provides maximum flexibility and speed
Elaborating further:
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerates the computation by a huge amount.
We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs, such as slicing, indexing, math operations, linear algebra, and reductions. And they are fast!
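A small sketch of the claims above: tensors work like NumPy arrays and can be moved to the GPU with the same API.

```python
import torch

a = torch.randn(4, 4)        # CPU tensor
b = a[:2, 1:3]               # slicing and indexing, NumPy-style
c = a @ a.t() + 1.0          # linear algebra and broadcast math
total = c.sum()              # reduction

if torch.cuda.is_available():
    a_gpu = a.cuda()         # same operations, now running on the GPU
    c_gpu = a_gpu @ a_gpu.t()
```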
PyTorch has a unique way of building neural networks: using and replaying a tape recorder.
Most frameworks such as TensorFlow, Theano, Caffe and CNTK have a static view of the world. One has to build a neural network, and reuse the same structure again and again. Changing the way the network behaves means that one has to start from scratch.
With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to change the way your network behaves arbitrarily with zero lag or overhead. Our inspiration comes from several research papers on this topic, as well as current and past work such as torch-autograd, autograd, Chainer, etc.
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research.
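A minimal sketch of define-by-run autograd in action: the graph is rebuilt on every iteration, so data-dependent control flow is differentiated naturally.

```python
import random
import torch

x = torch.randn(3, requires_grad=True)

for step in range(5):
    y = x * 2
    # Data-dependent control flow: the number of loop iterations
    # differs from run to run, and autograd handles it naturally.
    while y.norm() < random.randint(10, 100):
        y = y * 2
    y.sum().backward()   # reverse-mode differentiation through this run's graph
    x.grad.zero_()       # clear gradients before the next, differently-shaped graph
```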
PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use NumPy / SciPy / scikit-learn etc. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as Cython and Numba. Our goal is to not reinvent the wheel where appropriate.
PyTorch is designed to be intuitive, linear in thought and easy to use. When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. When you drop into a debugger, or receive error messages and stack traces, understanding them is straightforward. The stack trace points to exactly where your code was defined. We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines.
PyTorch has minimal framework overhead. We integrate acceleration libraries such as Intel MKL and NVIDIA (cuDNN, NCCL) to maximize speed. At the core, its CPU and GPU Tensor and neural network backends (TH, THC, THNN, THCUNN) are mature and have been tested for years.
Hence, PyTorch is quite fast – whether you run small or large neural networks.
The memory usage in PyTorch is extremely efficient compared to Torch or some of the alternatives. We've written custom memory allocators for the GPU to make sure that your deep learning models are maximally memory efficient. This enables you to train bigger deep learning models than before.
Writing new neural network modules, or interfacing with PyTorch's Tensor API, was designed to be straightforward and involve minimal abstractions.
You can write new neural network layers in Python using the torch API or your favorite NumPy-based libraries such as SciPy.
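For example, a new layer in pure Python is just an `nn.Module` subclass with a `forward` method; the `Affine` layer below is a toy example invented for illustration.

```python
import torch
import torch.nn as nn

class Affine(nn.Module):
    """A toy fully-connected layer: y = x W^T + b."""
    def __init__(self, in_features, out_features):
        super(Affine, self).__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # Gradients for weight and bias come from autograd automatically.
        return x.matmul(self.weight.t()) + self.bias

layer = Affine(3, 2)
out = layer(torch.randn(5, 3))   # shape (5, 2)
out.sum().backward()             # populates layer.weight.grad and layer.bias.grad
```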
If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and requires minimal boilerplate. No wrapper code needs to be written. You can see a tutorial here and an example here.
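As a hedged sketch of that workflow, the `torch.utils.cpp_extension` JIT helpers can compile a C++ function and bind it to Python in a few lines; the extension name and the `add_one` function below are illustrative, and a working C++ compiler is assumed.

```python
import torch
from torch.utils.cpp_extension import load_inline

cpp_source = """
at::Tensor add_one(at::Tensor x) {
    return x + 1;  // the full ATen tensor API is available in C++
}
"""

ext = load_inline(name="demo_ext",
                  cpp_sources=[cpp_source],
                  functions=["add_one"])

print(ext.add_one(torch.zeros(3)))  # tensor([1., 1., 1.])
```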
Commands to install from binaries via Conda or pip wheels are on our website:
If you are installing from source, we highly recommend installing an Anaconda environment. You will get a high-quality BLAS library (MKL) and you get a controlled compiler version regardless of your Linux distro.
Once you have Anaconda installed, here are the instructions.
If you want to compile with CUDA support, install NVIDIA CUDA 7.5 or above and NVIDIA cuDNN v6.x or above.
If you want to disable CUDA support, export the environment variable NO_CUDA=1. Other potentially useful environment variables may be found in setup.py.
If you want to build on Windows, the Visual Studio 2017 14.11 toolset and NVTX are also needed. In particular, for a CUDA 8 build on Windows, VS 2015 Update 3 and a patch for it are additionally required. The details of the patch can be found here.
On Linux
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory] # Install basic dependencies conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing conda install -c mingfeima mkldnn # Add LAPACK support for the GPU conda install -c pytorch magma-cuda80 # or magma-cuda90 if CUDA 9
On macOS
```bash
export CMAKE_PREFIX_PATH=[anaconda root directory]
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
```
On Windows
```bash
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
```
```bash
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
```
On Linux
```bash
python setup.py install
```
On macOS
```bash
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
```
On Windows
set "VS150COMNTOOLS=C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Auxiliary\Build" set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 set DISTUTILS_USE_SDK=1 REM The following two lines are needed for Python 2.7, but the support for it is very experimental. set MSSdk=1 set FORCE_PY27_BUILD=1 REM As for CUDA 8, VS2015 Update 3 is also required to build PyTorch. Use the following two lines. set "PREBUILD_COMMAND=%VS140COMNTOOLS%\..\..\VC\vcvarsall.bat" set PREBUILD_COMMAND_ARGS=x64 call "%VS150COMNTOOLS%\vcvarsall.bat" x64 -vcvars_ver=14.11 python setup.py install
A Dockerfile is supplied to build images with CUDA support and cuDNN v7. You can pass the -e PYTHON_VERSION=x.y flag to specify which Python version is to be used by Miniconda, or leave it unset to use the default. Build as usual:
```bash
docker build -t pytorch -f docker/pytorch/Dockerfile .
```
You can also pull a pre-built Docker image from Docker Hub and run it with nvidia-docker, but note that this image is not currently maintained and will pull PyTorch 0.2.
```bash
nvidia-docker run --rm -ti --ipc=host pytorch/pytorch:latest
```
Please note that PyTorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with is not enough, and you should increase the shared memory size with either the --ipc=host or --shm-size command line options to nvidia-docker run.
Installation instructions and binaries for previous PyTorch versions may be found on our website.
Three pointers to get you started:

- Tutorials: get you started with understanding and using PyTorch
- Examples: easy to understand pytorch code across all domains
- The API Reference
PyTorch has a 90-day release cycle (major releases). Its current state is Beta, and we expect no obvious bugs. Please let us know if you encounter a bug by filing an issue.
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.
If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.
PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
PyTorch is currently maintained by Adam Paszke, Sam Gross, Soumith Chintala and Gregory Chanan, with major contributions coming from tens of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Kopf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary DeVito.
Note: this project is unrelated to hughperkins/pytorch with the same name. Hugh is a valuable contributor in the Torch community and has helped with many things Torch and PyTorch.