Change requires_grad default to False
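
Variables are now constructed with requires_grad=False unless the flag is
passed explicitly, so leaf Variables that need gradients have to opt in.
The autograd, nn and utils tests are updated to pass requires_grad=True
wherever gradients are checked. nn modules now hand plain tensors to
Module.__init__, which wraps any non-Variable parameter in a Variable with
requires_grad=True, and the optimizer asserts that every parameter it
updates requires gradients.

A minimal sketch of the new default behaviour (the tensors below are
illustrative and not part of this patch):

    import torch
    from torch.autograd import Variable

    x = Variable(torch.ones(5, 5))                           # requires_grad now defaults to False
    y = Variable(torch.ones(5, 5) * 4, requires_grad=True)   # opt in to gradients explicitly
    z = x * y
    z.backward(torch.ones(5, 5))
    # y.grad receives the gradient; x is treated as a constant leaf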
diff --git a/test/test_autograd.py b/test/test_autograd.py
index e5af351..4aa06d9 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -38,8 +38,8 @@
class TestAutograd(TestCase):
def test_hooks(self):
- x = Variable(torch.ones(5, 5))
- y = Variable(torch.ones(5, 5) * 4)
+ x = Variable(torch.ones(5, 5), requires_grad=True)
+ y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
counter = [0]
def bw_hook(inc, grad):
@@ -65,10 +65,10 @@
y_t = torch.rand(5, 5) + 0.1
z_t = torch.randn(5, 5)
grad_output = torch.randn(5, 5)
- v = Variable(v_t)
- x = Variable(x_t)
- y = Variable(y_t)
- z = Variable(z_t)
+ v = Variable(v_t, requires_grad=True)
+ x = Variable(x_t, requires_grad=True)
+ y = Variable(y_t, requires_grad=True)
+ z = Variable(z_t, requires_grad=True)
v.backward(grad_output)
self.assertEqual(v.grad, grad_output)
@@ -83,7 +83,7 @@
self.assertEqual(z.grad, z_grad * grad_output)
def test_volatile(self):
- x = Variable(torch.ones(5, 5))
+ x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, volatile=True)
z = x ** 2
@@ -111,8 +111,8 @@
self.assertEqual(x[1, 2:], y[1, 2:].data)
def test_inplace(self):
- x = Variable(torch.ones(5, 5))
- y = Variable(torch.ones(5, 5) * 4)
+ x = Variable(torch.ones(5, 5), requires_grad=True)
+ y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
z = x * y
q = z + y
@@ -396,12 +396,12 @@
call_args = (call_args,)
def map_arg(arg):
if isinstance(arg, tuple) and not isinstance(arg[0], Variable):
- return Variable(torch.randn(*arg).double())
+ return Variable(torch.randn(*arg).double(), requires_grad=True)
elif torch.is_tensor(arg):
if isinstance(arg, torch.FloatTensor):
- return Variable(arg.double())
+ return Variable(arg.double(), requires_grad=True)
else:
- return Variable(arg)
+ return Variable(arg, requires_grad=True)
else:
return arg
return tuple(map_arg(arg) for arg in call_args)
diff --git a/test/test_nn.py b/test/test_nn.py
index a4f59e6..d13f065 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -20,7 +20,7 @@
if isinstance(i, Variable):
return i
elif torch.is_tensor(i):
- return Variable(i)
+ return Variable(i, requires_grad=True)
else:
return type(i)(map_variables(elem) for elem in i)
return map_variables(input)
@@ -170,7 +170,7 @@
def test_hooks(self):
module = nn.Sigmoid()
- input = Variable(torch.ones(5, 5))
+ input = Variable(torch.ones(5, 5), requires_grad=True)
counter = {
'forwards': 0,
@@ -258,14 +258,14 @@
input.fill_(1-p)
module = cls(p)
- input_var = Variable(input)
+ input_var = Variable(input, requires_grad=True)
output = module(input_var)
self.assertLess(abs(output.data.mean() - (1-p)), 0.05)
output.backward(input)
self.assertLess(abs(input_var.grad.mean() - (1-p)), 0.05)
module = cls(p, True)
- input_var = Variable(input.clone())
+ input_var = Variable(input.clone(), requires_grad=True)
output = module(input_var + 0)
self.assertLess(abs(output.data.mean() - (1-p)), 0.05)
output.backward(input)
@@ -381,7 +381,7 @@
module = module_cls(2, return_indices=True)
numel = 4 ** num_dim
input = torch.range(1, numel).view(1, 1, *repeat(4, num_dim))
- input_var = Variable(input)
+ input_var = Variable(input, requires_grad=True)
# Check forward
output, indices = module(input_var)
diff --git a/test/test_utils.py b/test/test_utils.py
index 5c38dda..781f1af 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -58,7 +58,7 @@
class ModelMock(object):
def __init__(self):
self.num_calls = 0
- self.output = Variable(torch.ones(1, 1))
+ self.output = Variable(torch.ones(1, 1), requires_grad=True)
def __call__(self, i):
self.num_calls += 1
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index 1045f0b..c0ee00a 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -20,7 +20,8 @@
'is_cuda',
]
- def __init__(self, tensor, creator=None, volatile=False, requires_grad=True):
+ def __init__(self, tensor, creator=None, volatile=False,
+ requires_grad=False):
self.creator = creator
self.volatile = volatile
self.dirty = False
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index fdc7532..83dc2d8 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -323,7 +323,7 @@
def __init__(self, num_parameters=1, init=0.25):
self.num_parameters = num_parameters
super(PReLU, self).__init__(
- weight=Variable(torch.Tensor(num_parameters).fill_(init))
+ weight=torch.Tensor(num_parameters).fill_(init)
)
def forward(self, input):
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index cbb3fe3..4670b96 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -14,8 +14,8 @@
weight = bias = None
if self.affine:
- weight = Variable(torch.Tensor(num_features))
- bias = Variable(torch.Tensor(num_features))
+ weight = torch.Tensor(num_features)
+ bias = torch.Tensor(num_features)
super(_BatchNorm, self).__init__(weight=weight, bias=bias)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py
index 3a4ef2f..b93574a 100644
--- a/torch/nn/modules/conv.py
+++ b/torch/nn/modules/conv.py
@@ -46,8 +46,8 @@
kernel_elements = self.in_features * self.kernel_size
super(Conv1d, self).__init__(
- weight = Variable(torch.Tensor(out_features, in_features, kernel_size)),
- bias = Variable(torch.Tensor(out_features))
+ weight = torch.Tensor(out_features, in_features, kernel_size),
+ bias = torch.Tensor(out_features)
)
self.reset_parameters()
@@ -119,9 +119,9 @@
self.dilh, self.dilw = _pair(dilation)
self.groups = groups
- weight = Variable(torch.Tensor(
- self.out_channels, self.in_channels, self.kh, self.kw))
- bias = None if no_bias else Variable(torch.Tensor(self.out_channels))
+ weight = torch.Tensor(self.out_channels, self.in_channels, self.kh,
+ self.kw)
+ bias = None if no_bias else torch.Tensor(self.out_channels)
super(Conv2d, self).__init__(
weight=weight,
bias=bias,
@@ -244,9 +244,9 @@
padding=0):
super(Conv3d, self).__init__(in_channels, out_channels, kernel_size,
stride, padding)
- weight = Variable(torch.Tensor(self.out_channels,
- self.in_channels, self.kt, self.kh, self.kw))
- bias = Variable(torch.Tensor(self.out_channels))
+ weight = torch.Tensor(self.out_channels, self.in_channels, self.kt,
+ self.kh, self.kw)
+ bias = torch.Tensor(self.out_channels)
Module.__init__(self, weight=weight, bias=bias)
self.reset_parameters()
@@ -291,9 +291,9 @@
padding=0):
super(FullConv3d, self).__init__(in_channels, out_channels, kernel_size,
stride, padding)
- weight = Variable(torch.Tensor(self.in_channels,
- self.out_channels, self.kt, self.kh, self.kw))
- bias = Variable(torch.Tensor(self.out_channels))
+ weight = torch.Tensor(self.in_channels, self.out_channels, self.kt,
+ self.kh, self.kw)
+ bias = torch.Tensor(self.out_channels)
Module.__init__(self, weight=weight, bias=bias)
self.reset_parameters()
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index 4e90a37..d3befeb 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -30,8 +30,8 @@
self.out_features = out_features
super(Linear, self).__init__(
- weight=Variable(torch.Tensor(out_features, in_features)),
- bias=Variable(torch.Tensor(out_features))
+ weight=torch.Tensor(out_features, in_features),
+ bias=torch.Tensor(out_features)
)
self.reset_parameters()
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 62dba4f..538c626 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -15,6 +15,10 @@
self.backward_hooks = OrderedDict()
self.forward_hooks = OrderedDict()
self.train = True
+ for name, param in self._parameters.items():
+ if param is not None and not isinstance(param, Variable):
+ param = Variable(param, requires_grad=True)
+ self._parameters[name] = param
def forward(self, *input):
raise NotImplementedError
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index 69a6653..029cfab 100644
--- a/torch/nn/modules/sparse.py
+++ b/torch/nn/modules/sparse.py
@@ -35,7 +35,7 @@
self.scale_grad_by_freq = scale_grad_by_freq
super(Embedding, self).__init__(
- weight=Variable(torch.Tensor(num_embeddings, embedding_dim))
+ weight=torch.Tensor(num_embeddings, embedding_dim)
)
self.reset_parameters()
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 2ca6b6e..292efc8 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -41,6 +41,8 @@
def _forward_backward(self, forward_closure):
for group in self.param_groups:
for p in group['params']:
+ assert p.requires_grad, "optimizing a parameter that doesn't " \
+ "require gradients"
p.grad.zero_()
loss = forward_closure()
loss.backward()