Revert D17513451: Register values listed in __constants__ as attributes of the Module.
Test Plan: revert-hammer
Differential Revision:
D17513451
Original commit changeset: cf8f9b450e71
fbshipit-source-id: 319ec9399173eb06556969dc6be365b319c1ab6c
diff --git a/test/test_jit.py b/test/test_jit.py
index 401bf88..5682b72 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -14573,69 +14573,6 @@
m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
self.assertEqual(m("c"), torch.tensor([103]))
- def test_module_const_attrs(self):
- class M(torch.nn.Module):
- __constants__ = ['i', 'non']
-
- def __init__(self):
- super(M, self).__init__()
- self.i = 10
- self.non = None
-
- def forward(self, x):
- # type: (int) -> int
- if self.non is None:
- return x + self.i
- else:
- return -1
-
- m = torch.jit.script(M())
- self.assertEqual(m(5), 15)
- self.assertTrue(m._c._has_attribute('i'))
- self.assertFalse(m._c._has_attribute('non'))
-
- def test_module_mutate_const_attrs(self):
- # Check that we cannot mutate a constant
- class M(torch.nn.Module):
- __constants__ = ['i', 'non']
-
- def __init__(self):
- super(M, self).__init__()
- self.i = 10
- self.non = None
-
- def forward(self, x):
- # type: (int) -> int
- self.i = 5
- if self.non is None:
- return x + self.i
- else:
- return -1
-
- with self.assertRaises(RuntimeError):
- m = torch.jit.script(M())
-
- def test_module_mutate_const_attrs_2(self):
- # Check that we cannot mutate constant of a mutable type (e.g. list)
- class M(torch.nn.Module):
- __constants__ = ['i', 'non']
-
- def __init__(self):
- super(M, self).__init__()
- self.i = [10, 20]
- self.non = None
-
- def forward(self, x):
- # type: (int) -> int
- self.i.append(30)
- if self.non is None:
- return x + self.i[0]
- else:
- return -1
-
- with self.assertRaises(RuntimeError):
- m = torch.jit.script(M())
-
def test_tensor_import_export(self):
@torch.jit.script
def foo(x):
diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp
index 1e79465..611757a 100644
--- a/torch/csrc/jit/script/python_sugared_value.cpp
+++ b/torch/csrc/jit/script/python_sugared_value.cpp
@@ -360,9 +360,7 @@
m.graph()->insertGetAttr(self_, field),
*v,
py_module_.attr(field.c_str()));
- }
-
- if (auto kind = module_.kind_of(field)) {
+ } else if (auto kind = module_.kind_of(field)) {
// methods, parameters, attributes, and buffers are all first class
return SimpleValue(self_).attr(loc, m, field);
}
@@ -449,24 +447,6 @@
}
}
- if (py_module_.attr("_constants_set").contains(field.c_str())) {
- // Values of the attributes listed in the _constants_set will be put
- // directly into IR. In order to allow us to access these values by their
- // name after IR is generated, we register them as attributes.
- if (py::isinstance<py::bool_>(attr)) {
- module_.register_attribute(field, BoolType::get(), py::cast<bool>(attr));
- } else if (py::isinstance<py::int_>(attr)) {
- module_.register_attribute(
- field, IntType::get(), py::cast<int64_t>(attr));
- } else if (py::isinstance<py::float_>(attr)) {
- module_.register_attribute(
- field, FloatType::get(), py::cast<double>(attr));
- } else if (py::isinstance<py::str>(attr)) {
- module_.register_attribute(
- field, StringType::get(), py::cast<std::string>(attr));
- }
- }
-
if (py::isinstance<py::function>(attr) ||
py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
py_module_.attr("_constants_set").contains(field.c_str())) {