blob: f876b07cd0cec8e48bab9f02d74d8e51b80df940 [file] [log] [blame]
from common import TestCase, run_tests
import torch
from torch.autograd import Variable
class TestIndexing(TestCase):
    """Basic, advanced (integer / byte-mask) and mixed indexing of Variables."""

    def test_single_int(self):
        """A lone integer index drops the leading dimension."""
        var = Variable(torch.randn(5, 7, 3))
        self.assertEqual(var[4].shape, (7, 3))

    def test_multiple_int(self):
        """Every integer in a tuple index removes one dimension."""
        var = Variable(torch.randn(5, 7, 3))
        row = var[4]
        self.assertEqual(row.shape, (7, 3))
        self.assertEqual(var[4, :, 1].shape, (7,))

    def test_none(self):
        """``None`` inserts a size-1 axis at its position in the index."""
        var = Variable(torch.randn(5, 7, 3))
        # (index, expected shape) pairs; each index is exactly what the
        # subscript syntax v[None], v[:, None], ... would pass to __getitem__.
        cases = [
            (None, (1, 5, 7, 3)),
            ((slice(None), None), (5, 1, 7, 3)),
            ((slice(None), None, None), (5, 1, 1, 7, 3)),
            ((Ellipsis, None), (5, 7, 3, 1)),
        ]
        for index, expected_shape in cases:
            self.assertEqual(var[index].shape, expected_shape)

    def test_step(self):
        """Positive slice steps, including a step longer than the tensor."""
        var = Variable(torch.arange(10))
        self.assertEqual(var[::1], var)
        stepped = ((2, [0, 2, 4, 6, 8]), (3, [0, 3, 6, 9]), (11, [0]))
        for step, expected in stepped:
            self.assertEqual(var[::step].data.tolist(), expected)
        self.assertEqual(var[1:6:2].data.tolist(), [1, 3, 5])

    def test_step_assignment(self):
        """Assigning through a stepped slice writes only the selected columns."""
        grid = Variable(torch.zeros(4, 4))
        grid[0, 1::2] = Variable(torch.Tensor([3, 4]))
        self.assertEqual(grid[0].data.tolist(), [0, 3, 0, 4])
        # Every row after the first must still be all zeros.
        self.assertEqual(grid[1:].data.sum(), 0)

    def test_byte_mask(self):
        """A 1-D byte mask keeps the leading-dim entries where it is non-zero."""
        var = Variable(torch.randn(5, 7, 3))
        mask = Variable(torch.ByteTensor([1, 0, 1, 1, 0]))
        selected = var[mask]
        self.assertEqual(selected.shape, (3, 7, 3))
        self.assertEqual(selected, torch.stack([var[0], var[2], var[3]]))

        # A mask that selects nothing yields an empty result.
        single = Variable(torch.Tensor([1]))
        self.assertEqual(single[single == 0], Variable(torch.Tensor()))

    def test_multiple_byte_mask(self):
        """Byte masks on two different dims broadcast against each other."""
        var = Variable(torch.randn(5, 7, 3))
        # note: these broadcast together and are transposed to the first dim
        first_mask = Variable(torch.ByteTensor([1, 0, 1, 1, 0]))
        last_mask = Variable(torch.ByteTensor([1, 1, 1]))
        self.assertEqual(var[first_mask, :, last_mask].shape, (3, 7))

    def test_byte_mask2d(self):
        """A 2-D mask collapses the two dims it covers into one."""
        var = Variable(torch.randn(5, 7, 3))
        cond = Variable(torch.randn(5, 7))
        positive = cond > 0
        num_ones = positive.data.sum()
        self.assertEqual(var[positive].shape, (num_ones, 3))

    def test_int_indices(self):
        """Integer lists select along the dimension they are applied to."""
        var = Variable(torch.randn(5, 7, 3))
        self.assertEqual(var[[0, 4, 2]].shape, (3, 7, 3))
        self.assertEqual(var[:, [0, 4, 2]].shape, (5, 3, 3))
        # A nested list contributes its full (2, 2) shape to the result.
        self.assertEqual(var[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))

    def test_int_indices2d(self):
        """Paired 2-D index tensors pick individual elements."""
        # From the NumPy indexing example
        src = Variable(torch.arange(0, 12).view(4, 3))
        rows = Variable(torch.LongTensor([[0, 0], [3, 3]]))
        cols = Variable(torch.LongTensor([[0, 2], [0, 2]]))
        self.assertEqual(src[rows, cols].data.tolist(), [[0, 2], [9, 11]])

    def test_int_indices_broadcast(self):
        """Index tensors of different shapes broadcast together."""
        # From the NumPy indexing example
        src = Variable(torch.arange(0, 12).view(4, 3))
        rows = Variable(torch.LongTensor([0, 3]))
        cols = Variable(torch.LongTensor([0, 2]))
        # rows[:, None] has shape (2, 1) and broadcasts with cols' (2,).
        corners = src[rows[:, None], cols]
        self.assertEqual(corners.data.tolist(), [[0, 2], [9, 11]])

    def test_empty_index(self):
        """An empty LongTensor index selects nothing."""
        src = Variable(torch.arange(0, 12).view(4, 3))
        empty = Variable(torch.LongTensor())
        self.assertEqual(src[empty].numel(), 0)

    def test_basic_advanced_combined(self):
        """Mixing a slice with a list index: reads copy, writes go through."""
        # From the NumPy indexing example
        src = Variable(torch.arange(0, 12).view(4, 3))
        self.assertEqual(src[1:2, 1:3], src[1:2, [1, 2]])
        self.assertEqual(src[1:2, 1:3].data.tolist(), [[4, 5]])

        # Reading through an advanced index produces a copy ...
        snapshot = src.clone()
        src[1:2, [1, 2]].zero_()
        self.assertEqual(src, snapshot)

        # ... whereas assigning through one modifies the original.
        snapshot = src.clone()
        src[1:2, [1, 2]] = 0
        self.assertNotEqual(src, snapshot)

    def test_int_assignment(self):
        """Assigning a scalar or a row Variable through an integer index."""
        grid = Variable(torch.arange(0, 4).view(2, 2))
        grid[1] = 5
        self.assertEqual(grid.data.tolist(), [[0, 1], [5, 5]])

        grid = Variable(torch.arange(0, 4).view(2, 2))
        grid[1] = Variable(torch.arange(5, 7))
        self.assertEqual(grid.data.tolist(), [[0, 1], [5, 6]])

    def test_byte_tensor_assignment(self):
        """Assigning through a row mask only overwrites the masked rows."""
        grid = Variable(torch.arange(0, 16).view(4, 4))
        row_mask = Variable(torch.ByteTensor([True, False, True, False]))
        value = Variable(torch.Tensor([3, 4, 5, 6]))
        grid[row_mask] = value

        self.assertEqual(grid[0], value)
        self.assertEqual(grid[1].data, torch.arange(4, 8))
        self.assertEqual(grid[2], value)
        self.assertEqual(grid[3].data, torch.arange(12, 16))
def tensor(*args, **kwargs):
    """Construct a ``torch.Tensor`` from the given arguments, wrapped in a ``Variable``."""
    base = torch.Tensor(*args, **kwargs)
    return Variable(base)
def byteTensor(data):
    """Return ``data`` converted to a ``torch.ByteTensor`` and wrapped in a ``Variable``."""
    as_bytes = torch.ByteTensor(data)
    return Variable(as_bytes)
def ones(*args):
    """Return a ``Variable`` wrapping ``torch.ones(*args)``."""
    filled = torch.ones(*args)
    return Variable(filled)
def zeros(*args):
    """Return a ``Variable`` wrapping ``torch.zeros(*args)``."""
    blank = torch.zeros(*args)
    return Variable(blank)
# The tests below are from NumPy test_indexing.py with some modifications to
# make them compatible with PyTorch. It's licensed under the BSD license below:
#
# Copyright (c) 2005-2017, NumPy Developers.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# * Neither the name of the NumPy Developers nor the names of any
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
class NumpyTests(TestCase):
    """Indexing tests adapted from NumPy's ``test_indexing.py``."""

    def test_index_no_floats(self):
        # Float scalars must be rejected as indices in every position.
        a = Variable(torch.Tensor([[[5]]]))

        self.assertRaises(IndexError, lambda: a[0.0])
        self.assertRaises(IndexError, lambda: a[0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0])
        self.assertRaises(IndexError, lambda: a[0.0, :])
        self.assertRaises(IndexError, lambda: a[:, 0.0])
        self.assertRaises(IndexError, lambda: a[:, 0.0, :])
        self.assertRaises(IndexError, lambda: a[0.0, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
        self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, 0.0, 0])

        self.assertRaises(IndexError, lambda: a[-1.4])
        self.assertRaises(IndexError, lambda: a[0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0])
        self.assertRaises(IndexError, lambda: a[-1.4, :])
        self.assertRaises(IndexError, lambda: a[:, -1.4])
        self.assertRaises(IndexError, lambda: a[:, -1.4, :])
        self.assertRaises(IndexError, lambda: a[-1.4, :, :])
        self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
        self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
        self.assertRaises(IndexError, lambda: a[0, -1.4, 0])

        # NumPy's checks for float slice bounds, kept disabled here.
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])

    def test_none_index(self):
        # `None` index adds newaxis
        a = tensor([1, 2, 3])
        self.assertEqual(a[None].dim(), a.dim() + 1)

    def test_empty_tuple_index(self):
        # Empty tuple index creates a view
        a = tensor([1, 2, 3])
        self.assertEqual(a[()], a)
        self.assertEqual(a[()].data_ptr(), a.data_ptr())

    def test_empty_fancy_index(self):
        # Empty list index creates an empty array
        a = tensor([1, 2, 3])
        self.assertEqual(a[[]], Variable(torch.Tensor()))

        # An empty LongTensor index must behave like the empty list.
        # (Previously this built `b` but re-tested `a[[]]`, so indexing with
        # an empty LongTensor was never exercised.)  `a` is a float tensor,
        # so the expected empty result is a float tensor too.
        b = tensor([]).long()
        self.assertEqual(a[b], Variable(torch.Tensor()))

        # A float tensor is not a valid index, even when empty.
        b = tensor([]).float()
        self.assertRaises(RuntimeError, lambda: a[b])

    def test_ellipsis_index(self):
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
        self.assertIsNot(a[...], a)
        self.assertEqual(a[...], a)
        # `a[...]` was `a` in numpy <1.9.
        self.assertEqual(a[...].data_ptr(), a.data_ptr())

        # Slicing with ellipsis can skip an
        # arbitrary number of dimensions
        self.assertEqual(a[0, ...], a[0])
        self.assertEqual(a[0, ...], a[0, :])
        self.assertEqual(a[..., 0], a[:, 0])

        # Slicing with ellipsis always results
        # in an array, not a scalar
        self.assertEqual(a[0, ..., 1], tensor([2]))

        # Assignment with `(Ellipsis,)` on 0-d arrays
        # b = np.array(1)
        # b[(Ellipsis,)] = 2
        # self.assertEqual(b, 2)

    def test_single_int_index(self):
        # Single integer index selects one row
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])

        self.assertEqual(a[0].data, [1, 2, 3])
        self.assertEqual(a[-1].data, [7, 8, 9])

        # Index out of bounds produces IndexError
        self.assertRaises(IndexError, a.__getitem__, 1 << 30)
        # Index overflow produces Exception  NB: different exception type
        self.assertRaises(Exception, a.__getitem__, 1 << 64)

    def test_single_bool_index(self):
        # Single boolean index
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])

        self.assertEqual(a[True], a[None])
        self.assertEqual(a[False], a[None][0:0])

    def test_boolean_shape_mismatch(self):
        arr = ones((5, 4, 3))

        # TODO: prefer IndexError
        index = byteTensor([True])
        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])

        index = byteTensor([False] * 6)
        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])

        index = Variable(torch.ByteTensor(4, 4)).zero_()
        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])

        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[(slice(None), index)])

    def test_boolean_indexing_onedim(self):
        # Indexing a 2-dimensional array with
        # boolean array of length one
        a = tensor([[0., 0., 0.]])
        b = byteTensor([True])
        self.assertEqual(a[b], a)
        # boolean assignment
        a[b] = 1.
        self.assertEqual(a, tensor([[1., 1., 1.]]))

    def test_boolean_assignment_value_mismatch(self):
        # A boolean assignment should fail when the shape of the values
        # cannot be broadcast to the subscription. (see also gh-3458)
        a = Variable(torch.arange(0, 4))

        def f(a, v):
            a[a > -1] = tensor(v)

        self.assertRaisesRegex(Exception, "expand", f, a, [])
        self.assertRaisesRegex(Exception, 'expand', f, a, [1, 2, 3])
        self.assertRaisesRegex(Exception, 'expand', f, a[:1], [1, 2, 3])

    def test_boolean_indexing_twodim(self):
        # Indexing a 2-dimensional array with
        # 2-dimensional boolean array
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
        b = byteTensor([[True, False, True],
                        [False, True, False],
                        [True, False, True]])
        self.assertEqual(a[b], tensor([1, 3, 5, 7, 9]))
        self.assertEqual(a[b[1]], tensor([[4, 5, 6]]))
        self.assertEqual(a[b[0]], a[b[2]])

        # boolean assignment
        a[b] = 0
        self.assertEqual(a, tensor([[0, 2, 0],
                                    [4, 0, 6],
                                    [0, 8, 0]]))

    def test_everything_returns_views(self):
        # Before `...` would return a itself.
        a = tensor(5)

        self.assertIsNot(a, a[()])
        self.assertIsNot(a, a[...])
        self.assertIsNot(a, a[:])

    def test_broaderrors_indexing(self):
        a = zeros(5, 5)
        self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2]))
        self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0)

    def test_trivial_fancy_out_of_bounds(self):
        a = zeros(5)
        # Out-of-range entry at the end of the index tensor.
        ind = ones(20).long()
        ind[-1] = 10
        self.assertRaises(RuntimeError, a.__getitem__, ind)
        self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
        # Out-of-range entry at the start of the index tensor.
        ind = ones(20).long()
        ind[0] = 11
        self.assertRaises(RuntimeError, a.__getitem__, ind)
        self.assertRaises(RuntimeError, a.__setitem__, ind, 0)

    def test_index_is_larger(self):
        # Simple case of fancy index broadcasting of the index.
        a = zeros((5, 5))
        a[[[0], [1], [2]], [0, 1, 2]] = tensor([2, 3, 4])

        self.assertTrue((a[:3, :3] == tensor([2, 3, 4])).all())

    def test_broadcast_subspace(self):
        a = zeros((100, 100))
        v = Variable(torch.arange(0, 100))[:, None]
        b = Variable(torch.arange(99, -1, -1).long())
        # Row b[i] receives the scalar v[i] broadcast across all columns.
        a[b] = v
        expected = b.double().unsqueeze(1).expand(100, 100)
        self.assertEqual(a, expected)
if __name__ == '__main__':
    # Run the tests in this module through the shared `run_tests` helper
    # imported from `common`.
    run_tests()