Move lshift to Aten (#31566)

Summary:
VitalyFedyunin , this PR is about move lshift to Aten.
Benchmark script :
```
import timeit
import torch
torch.manual_seed(1)

for n, t in [(10, 100000),(1000, 10000)]:
    print('__lshift__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a << b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")', number=t))
        for dtype in ('torch.float32', 'torch.float64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a << b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randn({n}, dtype = {dtype}, device="{device}"); b = torch.randn({n}, dtype = {dtype}, device="{device}")', number=t))

for n, t in [(10, 100000),(1000, 10000)]:
    print('__ilshift__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a << b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t))
        for dtype in ('torch.float32', 'torch.float64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a << b\nif "{device}" == "cuda": torch.cuda.synchronize()', setup=f'import torch; a = torch.randn({n}, dtype = {dtype}, device="{device}"); b = torch.tensor(5, dtype = {dtype}, device="{device}")', number=t))
```
Device: **Tesla P100, skx-8180**
Cuda verison: **9.0.176**

Before:
```
__lshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.31618343852460384
device: cpu, dtype: torch.uint8, 100000 times           0.31258584931492805
device: cpu, dtype: torch.int16, 100000 times           0.3140896391123533
device: cpu, dtype: torch.int32, 100000 times           0.34389012958854437
device: cpu, dtype: torch.int64, 100000 times           0.339566046372056
device: cpu, dtype: torch.float32, 100000 times         0.4180623721331358
device: cpu, dtype: torch.float64, 100000 times         0.4165227338671684
device: cuda, dtype: torch.int8, 100000 times           1.7851383443921804
device: cuda, dtype: torch.uint8, 100000 times          1.7842160519212484
device: cuda, dtype: torch.int16, 100000 times          1.789359962567687
device: cuda, dtype: torch.int32, 100000 times          1.7822618428617716
device: cuda, dtype: torch.int64, 100000 times          1.7968465769663453
device: cuda, dtype: torch.float32, 100000 times                1.8066061967983842
device: cuda, dtype: torch.float64, 100000 times                1.8046843251213431
__lshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.04618230368942022
device: cpu, dtype: torch.uint8, 10000 times            0.04634759668260813
device: cpu, dtype: torch.int16, 10000 times            0.040676115080714226
device: cpu, dtype: torch.int32, 10000 times            0.04404774494469166
device: cpu, dtype: torch.int64, 10000 times            0.04511771444231272
device: cpu, dtype: torch.float32, 10000 times          0.6887832451611757
device: cpu, dtype: torch.float64, 10000 times          0.5559549620375037
device: cuda, dtype: torch.int8, 10000 times            0.17996764183044434
device: cuda, dtype: torch.uint8, 10000 times           0.17970609478652477
device: cuda, dtype: torch.int16, 10000 times           0.17873135022819042
device: cuda, dtype: torch.int32, 10000 times           0.1781835313886404
device: cuda, dtype: torch.int64, 10000 times           0.17846618220210075
device: cuda, dtype: torch.float32, 10000 times         0.18056879844516516
device: cuda, dtype: torch.float64, 10000 times         0.18132662680000067
__ilshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.61110960226506
device: cpu, dtype: torch.uint8, 100000 times           0.6333359787240624
device: cpu, dtype: torch.int16, 100000 times           0.6345370784401894
device: cpu, dtype: torch.int32, 100000 times           0.6470990972593427
device: cpu, dtype: torch.int64, 100000 times           0.6587044578045607
device: cpu, dtype: torch.float32, 100000 times         0.7269002720713615
device: cpu, dtype: torch.float64, 100000 times         0.7217964073643088
device: cuda, dtype: torch.int8, 100000 times           1.9880435159429908
device: cuda, dtype: torch.uint8, 100000 times          1.986489498987794
device: cuda, dtype: torch.int16, 100000 times          2.0059875370934606
device: cuda, dtype: torch.int32, 100000 times          1.995262237265706
device: cuda, dtype: torch.int64, 100000 times          1.9974954994395375
device: cuda, dtype: torch.float32, 100000 times                2.00442770216614
device: cuda, dtype: torch.float64, 100000 times                2.009664717130363
__ilshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.08199594635516405
device: cpu, dtype: torch.uint8, 10000 times            0.08096733782440424
device: cpu, dtype: torch.int16, 10000 times            0.0734213450923562
device: cpu, dtype: torch.int32, 10000 times            0.0769620593637228
device: cpu, dtype: torch.int64, 10000 times            0.08650507684797049
device: cpu, dtype: torch.float32, 10000 times          0.7196345143020153
device: cpu, dtype: torch.float64, 10000 times          0.597336508333683
device: cuda, dtype: torch.int8, 10000 times            0.19723015930503607
device: cuda, dtype: torch.uint8, 10000 times           0.19754122477024794
device: cuda, dtype: torch.int16, 10000 times           0.19710093270987272
device: cuda, dtype: torch.int32, 10000 times           0.19611249305307865
device: cuda, dtype: torch.int64, 10000 times           0.19750046730041504
device: cuda, dtype: torch.float32, 10000 times         0.19680574722588062
device: cuda, dtype: torch.float64, 10000 times         0.19689027685672045
```
After:
```
__lshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.3031281465664506
device: cpu, dtype: torch.uint8, 100000 times           0.30772678554058075
device: cpu, dtype: torch.int16, 100000 times           0.3088294789195061
device: cpu, dtype: torch.int32, 100000 times           0.30907699652016163
device: cpu, dtype: torch.int64, 100000 times           0.31315001379698515
device: cpu, dtype: torch.float32, 100000 times         0.38823566399514675
device: cpu, dtype: torch.float64, 100000 times         0.39300001971423626
device: cuda, dtype: torch.int8, 100000 times           1.3225595457479358
device: cuda, dtype: torch.uint8, 100000 times          1.31739442050457
device: cuda, dtype: torch.int16, 100000 times          1.3198596313595772
device: cuda, dtype: torch.int32, 100000 times          1.309600466862321
device: cuda, dtype: torch.int64, 100000 times          1.3264533821493387
device: cuda, dtype: torch.float32, 100000 times                1.3377520674839616
device: cuda, dtype: torch.float64, 100000 times                1.3343619462102652
__lshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.02718757465481758
device: cpu, dtype: torch.uint8, 10000 times            0.02701799664646387
device: cpu, dtype: torch.int16, 10000 times            0.025483975186944008
device: cpu, dtype: torch.int32, 10000 times            0.025557605549693108
device: cpu, dtype: torch.int64, 10000 times            0.026179466396570206
device: cpu, dtype: torch.float32, 10000 times          0.0962932649999857
device: cpu, dtype: torch.float64, 10000 times          0.1611471576616168
device: cuda, dtype: torch.int8, 10000 times            0.13165222201496363
device: cuda, dtype: torch.uint8, 10000 times           0.13358880020678043
device: cuda, dtype: torch.int16, 10000 times           0.1342075066640973
device: cuda, dtype: torch.int32, 10000 times           0.1328689968213439
device: cuda, dtype: torch.int64, 10000 times           0.13336248509585857
device: cuda, dtype: torch.float32, 10000 times         0.1345295710489154
device: cuda, dtype: torch.float64, 10000 times         0.14084953162819147
__ilshift__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.19080814253538847
device: cpu, dtype: torch.uint8, 100000 times           0.18541878275573254
device: cpu, dtype: torch.int16, 100000 times           0.19136024825274944
device: cpu, dtype: torch.int32, 100000 times           0.1916898973286152
device: cpu, dtype: torch.int64, 100000 times           0.1973192635923624
device: cpu, dtype: torch.float32, 100000 times         0.2668355852365494
device: cpu, dtype: torch.float64, 100000 times         0.24472137168049812
device: cuda, dtype: torch.int8, 100000 times           1.3581306440755725
device: cuda, dtype: torch.uint8, 100000 times          1.3522163443267345
device: cuda, dtype: torch.int16, 100000 times          1.366145665757358
device: cuda, dtype: torch.int32, 100000 times          1.3674909211695194
device: cuda, dtype: torch.int64, 100000 times          1.3734915973618627
device: cuda, dtype: torch.float32, 100000 times                1.3831533305346966
device: cuda, dtype: torch.float64, 100000 times                1.396162535995245
__ilshift__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.02847585454583168
device: cpu, dtype: torch.uint8, 10000 times            0.02960751298815012
device: cpu, dtype: torch.int16, 10000 times            0.028516249731183052
device: cpu, dtype: torch.int32, 10000 times            0.02842544950544834
device: cpu, dtype: torch.int64, 10000 times            0.029186096973717213
device: cpu, dtype: torch.float32, 10000 times          0.0999628696590662
device: cpu, dtype: torch.float64, 10000 times          0.16676222812384367
device: cuda, dtype: torch.int8, 10000 times            0.13856443110853434
device: cuda, dtype: torch.uint8, 10000 times           0.13766566663980484
device: cuda, dtype: torch.int16, 10000 times           0.13652489613741636
device: cuda, dtype: torch.int32, 10000 times           0.13678150344640017
device: cuda, dtype: torch.int64, 10000 times           0.13749946560710669
device: cuda, dtype: torch.float32, 10000 times         0.13879029918462038
device: cuda, dtype: torch.float64, 10000 times         0.14587809145450592
```

Fix https://github.com/pytorch/pytorch/issues/24510 #24514 https://github.com/pytorch/pytorch/issues/24657  https://github.com/pytorch/pytorch/issues/24661
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31566

Differential Revision: D19314251

Pulled By: ezyang

fbshipit-source-id: 52df17b2c18ef1880374c6dbcf18fb1118086552
14 files changed
tree: 63438a8743f250ac09d8d797c189591c50a4fa3a
  1. .circleci/
  2. .ctags.d/
  3. .github/
  4. .jenkins/
  5. android/
  6. aten/
  7. benchmarks/
  8. binaries/
  9. c10/
  10. caffe2/
  11. cmake/
  12. docker/
  13. docs/
  14. ios/
  15. modules/
  16. scripts/
  17. submodules/
  18. test/
  19. third_party/
  20. tools/
  21. torch/
  22. .clang-format
  23. .clang-tidy
  24. .flake8
  25. .gitattributes
  26. .gitignore
  27. .gitmodules
  28. .python2
  29. .travis.aten.yml
  30. CITATION
  31. CMakeLists.txt
  32. CODEOWNERS
  33. CONTRIBUTING.md
  34. LICENSE
  35. Makefile
  36. mypy-files.txt
  37. mypy-README.md
  38. mypy.ini
  39. NOTICE
  40. README.md
  41. requirements.txt
  42. setup.py
  43. ubsan.supp
  44. version.txt
README.md

PyTorch Logo


PyTorch is a Python package that provides two high-level features:

  • Tensor computation (like NumPy) with strong GPU acceleration
  • Deep neural networks built on a tape-based autograd system

You can reuse your favorite Python packages such as NumPy, SciPy and Cython to extend PyTorch when needed.

System2.73.53.6
Linux CPUBuild StatusBuild Status
Linux GPUBuild StatusBuild Status
Windows CPU / GPUBuild Status
Linux (ppc64le) CPUBuild StatusBuild Status
Linux (ppc64le) GPUBuild StatusBuild Status

See also the ci.pytorch.org HUD.

More About PyTorch

At a granular level, PyTorch is a library that consists of the following components:

ComponentDescription
torcha Tensor library like NumPy, with strong GPU support
torch.autograda tape-based automatic differentiation library that supports all differentiable Tensor operations in torch
torch.jita compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code
torch.nna neural networks library deeply integrated with autograd designed for maximum flexibility
torch.multiprocessingPython multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training
torch.utilsDataLoader and other utility functions for convenience

Usually PyTorch is used either as:

  • a replacement for NumPy to use the power of GPUs.
  • a deep learning research platform that provides maximum flexibility and speed.

Elaborating further:

A GPU-Ready Tensor Library

If you use NumPy, then you have used Tensors (a.k.a ndarray).

Tensor illustration

PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerates the computation by a huge amount.

We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs such as slicing, indexing, math operations, linear algebra, reductions. And they are fast!

Dynamic Neural Networks: Tape-Based Autograd

PyTorch has a unique way of building neural networks: using and replaying a tape recorder.

Most frameworks such as TensorFlow, Theano, Caffe and CNTK have a static view of the world. One has to build a neural network, and reuse the same structure again and again. Changing the way the network behaves means that one has to start from scratch.

With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to change the way your network behaves arbitrarily with zero lag or overhead. Our inspiration comes from several research papers on this topic, as well as current and past work such as torch-autograd, autograd, Chainer, etc.

While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research.

Dynamic graph

Python First

PyTorch is not a Python binding into a monolithic C++ framework. It is built to be deeply integrated into Python. You can use it naturally like you would use NumPy / SciPy / scikit-learn etc. You can write your new neural network layers in Python itself, using your favorite libraries and use packages such as Cython and Numba. Our goal is to not reinvent the wheel where appropriate.

Imperative Experiences

PyTorch is designed to be intuitive, linear in thought and easy to use. When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. When you drop into a debugger, or receive error messages and stack traces, understanding them is straightforward. The stack trace points to exactly where your code was defined. We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines.

Fast and Lean

PyTorch has minimal framework overhead. We integrate acceleration libraries such as Intel MKL and NVIDIA (cuDNN, NCCL) to maximize speed. At the core, its CPU and GPU Tensor and neural network backends (TH, THC, THNN, THCUNN) are mature and have been tested for years.

Hence, PyTorch is quite fast – whether you run small or large neural networks.

The memory usage in PyTorch is extremely efficient compared to Torch or some of the alternatives. We've written custom memory allocators for the GPU to make sure that your deep learning models are maximally memory efficient. This enables you to train bigger deep learning models than before.

Extensions Without Pain

Writing new neural network modules, or interfacing with PyTorch's Tensor API was designed to be straightforward and with minimal abstractions.

You can write new neural network layers in Python using the torch API or your favorite NumPy-based libraries such as SciPy.

If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and with minimal boilerplate. No wrapper code needs to be written. You can see a tutorial here and an example here.

Installation

Binaries

Commands to install from binaries via Conda or pip wheels are on our website: https://pytorch.org

NVIDIA Jetson platforms

Python wheels for NVIDIA's Jetson Nano, Jetson TX2, and Jetson AGX Xavier are available via the following URLs:

They require JetPack 4.2 and above, and @dusty-nv maintains them

From Source

If you are installing from source, you will need a C++14 compiler. Also, we highly recommend installing an Anaconda environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro.

Once you have Anaconda installed, here are the instructions.

If you want to compile with CUDA support, install

If you want to disable CUDA support, export environment variable USE_CUDA=0. Other potentially useful environment variables may be found in setup.py.

If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xavier), Instructions to are available here

Install Dependencies

Common (only install typing for Python <3.5)

conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing

On Linux

# Add LAPACK support for the GPU if needed
conda install -c pytorch magma-cuda90 # or [magma-cuda92 | magma-cuda100 | magma-cuda101 ] depending on your cuda version

Get the PyTorch Source

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
# if you are updating an existing checkout
git submodule sync
git submodule update --init --recursive

Install PyTorch

On Linux

export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
python setup.py install

On macOS

export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install

Each CUDA version only supports one particular XCode version. The following combinations have been reported to work with PyTorch.

CUDA versionXCode version
10.0XCode 9.4
10.1XCode 10.1

On Windows

At least Visual Studio 2017 Update 3 (version 15.3.3 with the toolset 14.11) and NVTX are needed.

If the version of Visual Studio 2017 is higher than 15.4.5, installing of “VC++ 2017 version 15.4 v14.11 toolset” is strongly recommended.
If the version of Visual Studio 2017 is lesser than 15.3.3, please update Visual Studio 2017 to the latest version along with installing “VC++ 2017 version 15.4 v14.11 toolset”.
There is no guarantee of the correct building with VC++ 2017 toolsets, others than version 15.4 v14.11.
“VC++ 2017 version 15.4 v14.11 toolset” might be installed onto already installed Visual Studio 2017 by running its installation once again and checking the corresponding checkbox under “Individual components”/“Compilers, build tools, and runtimes”.

NVTX is a part of CUDA distributive, where it is called “Nsight Compute”. To install it onto already installed CUDA run CUDA installation once again and check the corresponding checkbox. Be sure that CUDA with Nsight Compute is installed after Visual Studio 2017.

Currently VS 2017, VS 2019 and Ninja are supported as the generator of CMake. If ninja.exe is detected in PATH, then Ninja will be used as the default generator, otherwise it will use VS 2017.
If Ninja is selected as the generator, the latest MSVC which is newer than VS 2015 (14.0) will get selected as the underlying toolchain if you have Python > 3.5, otherwise VS 2015 will be selected so you'll have to activate the environment. If you use CMake <= 3.14.2 and has VS 2019 installed, then even if you specify VS 2017 as the generator, VS 2019 will get selected as the generator.

CUDA and MSVC have strong version dependencies, so even if you use VS 2017 / 2019, you will get build errors like nvcc fatal : Host compiler targets unsupported OS. For this kind of problem, please install the corresponding VS toolchain in the table below and then you can either specify the toolset during activation (recommended) or set CUDAHOSTCXX to override the cuda host compiler (not recommended if there are big version differences).

CUDA versionNewest supported VS version
9.0 / 9.1Visual Studio 2017 Update 4 (15.4) (_MSC_VER <= 1911)
9.2Visual Studio 2017 Update 5 (15.5) (_MSC_VER <= 1912)
10.0Visual Studio 2017 (15.X) (_MSC_VER < 1920)
10.1Visual Studio 2019 (16.X) (_MSC_VER < 1930)
cmd
:: [Optional] Only add the next two lines if you need Python 2.7. If you use Python 3, ignore these two lines.
set MSSdk=1
set FORCE_PY27_BUILD=1

:: [Optional] If you want to build with VS 2019 generator, please change the value in the next line to `Visual Studio 16 2019`.
:: Note: This value is useless if Ninja is detected. However, you can force that by using `set USE_NINJA=OFF`.
set CMAKE_GENERATOR=Visual Studio 15 2017

:: Read the content in the previous section carefully before you proceed.
:: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block.
:: "Visual Studio 2017 Developer Command Prompt" will be run automatically.
:: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator.
:: It's an essential step if you use Python 3.5.
set CMAKE_GENERATOR_TOOLSET_VERSION=14.11
set DISTUTILS_USE_SDK=1
for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,16^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION%

:: [Optional] If you want to override the cuda host compiler
set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Tools\MSVC\14.11.25503\bin\HostX64\x64\cl.exe

python setup.py install

Adjust Build Options (Optional)

You can adjust the configuration of cmake variables optionally (without building first), by doing the following. For example, adjusting the pre-detected directories for CuDNN or BLAS can be done with such a step.

On Linux

export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
python setup.py build --cmake-only
ccmake build  # or cmake-gui build

On macOS

export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build --cmake-only
ccmake build  # or cmake-gui build

Docker Image

Dockerfile is supplied to build images with cuda support and cudnn v7. You can pass -e PYTHON_VERSION=x.y flag to specify which Python version is to be used by Miniconda, or leave it unset to use the default. Build from pytorch repo directory as docker needs to copy git repo into docker filesystem while building the image.

docker build -t pytorch -f docker/pytorch/Dockerfile .  # [optional] --build-arg WITH_TORCHVISION=0

You can also pull a pre-built docker image from Docker Hub and run with nvidia-docker, but this is not currently maintained and will pull PyTorch 0.2.

nvidia-docker run --rm -ti --ipc=host pytorch/pytorch:latest

Please note that PyTorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g. for multithreaded data loaders) the default shared memory segment size that container runs with is not enough, and you should increase shared memory size either with --ipc=host or --shm-size command line options to nvidia-docker run.

Building the Documentation

To build documentation in various formats, you will need Sphinx and the readthedocs theme.

cd docs/
pip install -r requirements.txt

You can then build the documentation by running make <format> from the docs/ folder. Run make to get a list of all available output formats.

Previous Versions

Installation instructions and binaries for previous PyTorch versions may be found on our website.

Getting Started

Three pointers to get you started:

Communication

  • forums: discuss implementations, research, etc. https://discuss.pytorch.org
  • GitHub issues: bug reports, feature requests, install issues, RFCs, thoughts, etc.
  • Slack: The PyTorch Slack hosts a primary audience of moderate to experienced PyTorch users and developers for general chat, online discussions, collaboration etc. If you are a beginner looking for help, the primary medium is PyTorch Forums. If you need a slack invite, please fill this form: https://goo.gl/forms/PP1AGvNHpSaJP8to1
  • newsletter: no-noise, one-way email newsletter with important announcements about pytorch. You can sign-up here: https://eepurl.com/cbG0rv

Releases and Contributing

PyTorch has a 90 day release cycle (major releases). Please let us know if you encounter a bug by filing an issue.

We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.

If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.

The Team

PyTorch is a community driven project with several skillful engineers and researchers contributing to it.

PyTorch is currently maintained by Adam Paszke, Sam Gross, Soumith Chintala and Gregory Chanan with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito.

Note: this project is unrelated to hughperkins/pytorch with the same name. Hugh is a valuable contributor in the Torch community and has helped with many things Torch and PyTorch.

License

PyTorch is BSD-style licensed, as found in the LICENSE file.