Allow 'Any' to appear as a type argument. (#26572)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26572
Combined with isinstance specialization, this allows polymorphic functions to
work to a degree without needing to resort to our weirder overload hacks.
We do not define any operators on Any, so the only things you can do with an Any
value are to put it in containers or to type-refine it using an isinstance check.
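For example, a minimal sketch of isinstance refinement on Any arguments,
mirroring the any_refinement test added to test_jit.py in this diff:

    import torch
    from typing import Any

    @torch.jit.script
    def any_refinement(a, b):
        # type: (Any, Any) -> int
        # Any supports no operators; isinstance refines a and b to int first.
        if isinstance(a, int) and isinstance(b, int):
            return a + b
        return 0

    assert any_refinement(3, 4) == 7     # both arguments refine to int
    assert any_refinement(3, "hi") == 0  # refinement fails, fall through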
Any is restricted from appearing in non-argument position because we
cannot restore type tags if it ends up as a field in a class.
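A minimal sketch of that restriction, based on the test_any_in_class_fails
tests added below (the try/except wrapper is only for illustration); scripting
a class whose attribute type contains Any is rejected:

    import torch
    from typing import Any, Tuple

    try:
        @torch.jit.script
        class Foo(object):
            def __init__(self, a):
                # type: (Tuple[int, Any]) -> None
                self.a = a  # attribute type would contain Any
    except RuntimeError as e:
        # checkNoAny rejects Any inside module, class, and NamedTuple attributes
        assert "contains an Any" in str(e)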
Test Plan: Imported from OSS
Differential Revision: D17530643
Pulled By: zdevito
fbshipit-source-id: f06f78ce84819f7773953a492f3d4c49219ee94c
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 5792ade..498f4bc 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -782,6 +782,19 @@
c10::optional<QualifiedName> name_;
};
+// Any should never appear in a named type like a class, namedtuple or
+// interface. If it does, then dynamic type information will be lost in the
+// Pickler, leading to hard-to-track-down bugs that will only occur
+// after saving or loading a model. This is because we rely on the
+// static types in named types to reconstruct type tags of loaded
+// values. Lifting this restriction requires solving the serialization
+// problem first.
+CAFFE2_API void checkNoAny(
+ const Type& base,
+ const char* what,
+ const std::string& attrname,
+ const TypePtr& attrtype);
+
struct TupleType;
using TupleTypePtr = std::shared_ptr<TupleType>;
using NameList = std::vector<std::string>;
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index 1bb7959..97cc911 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -583,6 +583,11 @@
std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
return v->hasFreeVariables();
});
+ if (schema_) {
+ for (const Argument& arg : schema_->arguments()) {
+ checkNoAny(*this, "attribute", arg.name(), arg.type());
+ }
+ }
}
bool TupleType::isSubtypeOfExt(const TypePtr rhs_, std::ostream* why_not) const {
@@ -718,4 +723,34 @@
InterfaceType::~InterfaceType() = default;
+
+static bool containsAny(const TypePtr& type) {
+ std::vector<TypePtr> to_scan = { type };
+ while (!to_scan.empty()) {
+ TypePtr typ = to_scan.back();
+ to_scan.pop_back();
+ if (typ->kind() == AnyType::Kind) {
+ return true;
+ }
+ for (const TypePtr& sub : typ->containedTypes()) {
+ to_scan.emplace_back(sub);
+ }
+ }
+ return false;
+}
+
+void checkNoAny(const Type& base, const char* what, const std::string& attrname, const TypePtr& attrtype) {
+ TORCH_CHECK(
+ !containsAny(attrtype),
+ "attempting to add ",
+ what,
+ " '",
+ attrname,
+ "' of type ",
+ attrtype->python_str(),
+ " to '",
+ base.python_str(),
+ "' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples.");
+}
+
} // namespace c10
diff --git a/test/cpp/jit/test_autodiff.cpp b/test/cpp/jit/test_autodiff.cpp
index 6bde65e..b9abd5f 100644
--- a/test/cpp/jit/test_autodiff.cpp
+++ b/test/cpp/jit/test_autodiff.cpp
@@ -66,8 +66,7 @@
auto input_typeptr = TupleType::create(std::move(input_types));
std::shared_ptr<tracer::TracingState> state;
Stack trace_stack_in;
- std::tie(state, trace_stack_in) =
- tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
+ std::tie(state, trace_stack_in) = tracer::enter(input_vars);
variable_list trace_vars_in = fmap(
trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
auto trace_vars_out = test(trace_vars_in);
diff --git a/test/test_jit.py b/test/test_jit.py
index 60f79f5..d32ffe6 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -14,7 +14,7 @@
from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes
from torch._six import inf, PY2, PY37, StringIO
from torch.autograd import Variable, Function
-from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401
from torch.jit.frontend import NotSupportedError
from torch.onnx import OperatorExportTypes
from torch.testing import FileCheck
@@ -6987,6 +6987,27 @@
self.assertEqual(foo2(None, 4), 0)
self.assertEqual(foo2(4, None), 0)
+ @torch.jit.script
+ def any_refinement(a, b):
+ # type: (Any, Any) -> int
+ if isinstance(a, int) and isinstance(b, int):
+ return a + b
+ return 0
+
+ self.assertEqual(any_refinement(3, 4), 7)
+ self.assertEqual(any_refinement(3, "hi"), 0)
+
+ def test_any_in_class_fails(self):
+ with self.assertRaisesRegex(RuntimeError, "contains an Any"):
+ @torch.jit.script
+ class Foo(object):
+ def __init__(self, a):
+ # type: (Tuple[int,Any]) -> None
+ self.a = a
+
+ def hi(self):
+ pass
+
def test_isinstance(self):
# test isinstance operator for static type checking
template = dedent('''
diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py
index 41d28d4..7b39821 100644
--- a/test/test_jit_py3.py
+++ b/test/test_jit_py3.py
@@ -1,7 +1,7 @@
from common_utils import run_tests
from jit_utils import JitTestCase
from torch.testing import FileCheck
-from typing import NamedTuple, List, Optional
+from typing import NamedTuple, List, Optional, Any
import unittest
import sys
import torch
@@ -230,5 +230,17 @@
x : Optional[int] = 7
+ def test_any_in_class_fails(self):
+ class MyCoolNamedTuple(NamedTuple):
+ a : Any
+ b : float
+ c : List[int]
+ with self.assertRaisesRegex(RuntimeError, "contains an Any"):
+ @torch.jit.script
+ def foo():
+ return MyCoolNamedTuple(4, 5.5, [3])
+ print(foo.graph)
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 33ea953..9c4c65b 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -521,7 +521,7 @@
try:
import typing
- from typing import Tuple, List, Dict, Optional
+ from typing import Tuple, List, Dict, Optional, Any
def is_tuple(ann):
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
@@ -607,10 +607,14 @@
def __getitem__(self, types):
return OptionalInstance(types)
+ class AnyCls(object):
+ pass
+
Tuple = TupleCls() # noqa: T484
List = ListCls() # noqa: T484
Dict = DictCls() # noqa: T484
Optional = DictCls() # noqa: T484
+ Any = AnyCls() # noqa: T484
def is_tuple(ann):
return isinstance(ann, TupleInstance)
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 3dfd066..71eb5f0 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -243,7 +243,7 @@
Stack stack;
stack.reserve(inputs.size()); // captures?
for (auto& obj : inputs) {
- stack.push_back(toIValue(obj));
+ stack.push_back(toTypeInferredIValue(obj));
}
ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
arg_spec_creator.specializeTypes(*graph, spec);
@@ -314,7 +314,7 @@
[](std::shared_ptr<Graph> g,
py::tuple args,
const std::string& unqualified_op_name) {
- auto stack = toStack(args);
+ auto stack = toTraceableStack(args);
checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
})
.def(
@@ -535,7 +535,7 @@
// Convert the output of the user-supplied funciton to IValue. The type
// information of this IValue is used both to record the correct type in
// the trace.
- output_ivalue = toIValue(py_func_output);
+ output_ivalue = toTypeInferredIValue(py_func_output);
Value* out_val = jit::tracer::getValueTrace(output_ivalue);
body_block->registerOutput(out_val);
node_output =
@@ -556,7 +556,7 @@
return PythonFutureWrapper(retval);
} else {
- auto result = toIValue(f(*args_tup));
+ auto result = toTypeInferredIValue(f(*args_tup));
auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
retval->markCompleted(std::move(result));
return PythonFutureWrapper(retval);
diff --git a/torch/csrc/jit/pybind.h b/torch/csrc/jit/pybind.h
index e1b7c3b..a6cefa1 100644
--- a/torch/csrc/jit/pybind.h
+++ b/torch/csrc/jit/pybind.h
@@ -27,7 +27,7 @@
bool load(handle src, bool) {
try {
- value = torch::jit::toIValue(src);
+ value = torch::jit::toTypeInferredIValue(src);
return true;
} catch (std::exception& e) {
return false;
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 58b87c7..14d33b1 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -44,8 +44,6 @@
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.
-using tracer::TypedStack;
-
inline std::shared_ptr<script::CompilationUnit> get_python_cu() {
return py::module::import("torch.jit")
.attr("_python_cu")
@@ -287,37 +285,24 @@
return false;
}
-inline TypedIValue toTraceableIValue(py::handle input) {
+inline IValue toTypeInferredIValue(py::handle input) {
auto match = tryToInferType(input);
if (!match.success()) {
AT_ERROR(
"Tracer cannot infer type of ", py::str(input), "\n:", match.reason());
}
- auto type = match.type();
+ return toIValue(input, match.type());
+}
- if (isTraceableType(type)) {
- return TypedIValue(toIValue(input, type), type);
- }
-
- AT_ERROR(
+inline Stack toTraceableStack(const py::tuple& inputs) {
+ auto info = toTypeInferredIValue(inputs);
+ AT_CHECK(
+ isTraceableType(info.type()),
"Type '",
- type->python_str(),
+ info.type()->python_str(),
"' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
" Tuples of Tensors can be traced");
-}
-
-inline IValue toIValue(py::handle input) {
- return toTraceableIValue(input).ivalue();
-}
-
-inline Stack toStack(const py::tuple& inputs) {
- return toIValue(inputs).toTuple()->elements();
-}
-
-inline TypedStack toTypedStack(const py::tuple& inputs) {
- auto info = toTraceableIValue(inputs);
- return TypedStack(
- info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
+ return info.toTuple()->elements();
}
inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
@@ -545,7 +530,7 @@
case TypeKind::CapsuleType:
AT_ERROR("Capsule Values aren't supported");
case TypeKind::AnyType:
- AT_ERROR("AnyType Values aren't supported");
+ return toTypeInferredIValue(obj);
}
AT_ERROR(
"Missing cases in toIValue for type: ",
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 921e24f..e1b2dac 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -678,6 +678,8 @@
return self->isSubtypeOf(other);
});
+ py::class_<AnyType, Type, std::shared_ptr<AnyType>>(m, "AnyType")
+ .def_static("get", &AnyType::get);
py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
.def_static("get", &NumberType::get);
py::class_<IntType, Type, std::shared_ptr<IntType>>(m, "IntType")
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index a05e028..bd78d85 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -51,7 +51,7 @@
std::shared_ptr<torch::jit::Graph> createGraphByTracing(
const py::function& func,
- TypedStack trace_inputs,
+ Stack trace_inputs,
const py::function& var_name_lookup_fn,
bool force_outplace,
script::Module* self) {
@@ -78,7 +78,7 @@
"The traced function didn't return any values! Side-effects are not "
"captured in traces, so it would be a no-op.");
}
- tracer::exit({toIValue(out)});
+ tracer::exit({toTypeInferredIValue(out)});
if (script::getInlineEverythingMode()) {
Inline(*graph);
}
@@ -161,10 +161,10 @@
m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
m.def("_tracer_enter", [](py::args trace_inputs) {
- return tracer::enter(toTypedStack(trace_inputs));
+ return tracer::enter(toTraceableStack(trace_inputs));
});
m.def("_tracer_exit", [](py::tuple var_outputs) {
- tracer::exit(toStack(var_outputs));
+ tracer::exit(toTraceableStack(var_outputs));
});
m.def("_tracer_abandon", []() { tracer::abandon(); });
m.def("_get_tracing_state", []() { return getTracingState(); });
diff --git a/torch/csrc/jit/python_tracer.h b/torch/csrc/jit/python_tracer.h
index 24be865..d865017 100644
--- a/torch/csrc/jit/python_tracer.h
+++ b/torch/csrc/jit/python_tracer.h
@@ -29,7 +29,7 @@
std::shared_ptr<Graph> createGraphByTracing(
const py::function& func,
- TypedStack inputs,
+ Stack inputs,
const py::function& var_name_lookup_fn,
bool force_outplace,
script::Module* self = nullptr);
diff --git a/torch/csrc/jit/script/class_type.cpp b/torch/csrc/jit/script/class_type.cpp
index 3c1819a..210c7b9 100644
--- a/torch/csrc/jit/script/class_type.cpp
+++ b/torch/csrc/jit/script/class_type.cpp
@@ -65,17 +65,21 @@
const std::string& name,
TypePtr type,
bool is_parameter) {
+ const char* what = is_parameter ? "parameter" : "attribute";
for (size_t i = 0; i < attributeNames_.size(); ++i) {
TORCH_CHECK(
name != attributeNames_[i],
"attempting to add ",
- is_parameter ? "parameter"
- : "attribute"
- " '",
+ what,
+ " '",
name,
- "' but a field of the same name already exists with type ",
+ "' to ",
+ python_str(),
+ " but a field of the same name already exists with type ",
attributeTypes_[i]->python_str());
}
+ checkNoAny(*this, what, name, type);
+
size_t slot = attributeNames_.size();
attributeNames_.push_back(name);
attributeTypes_.push_back(type);
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 54bc4c9..19b493e 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -599,7 +599,7 @@
bool force_outplace) {
// prereq: Module's buffers and parameters are unique
// this was ensured in python before calling this function
- auto typed_inputs = toTypedStack(input_tuple);
+ auto typed_inputs = toTraceableStack(input_tuple);
auto graph = tracer::createGraphByTracing(
func, typed_inputs, var_lookup_fn, force_outplace, &self);
const auto method_name = QualifiedName(self.name(), name);
@@ -805,7 +805,7 @@
py::tuple input_tuple,
py::function var_lookup_fn,
bool force_outplace) {
- auto typed_inputs = toTypedStack(input_tuple);
+ auto typed_inputs = toTraceableStack(input_tuple);
auto graph = tracer::createGraphByTracing(
func, typed_inputs, var_lookup_fn, force_outplace);
auto cu = get_python_cu();
diff --git a/torch/csrc/jit/script/script_type_parser.cpp b/torch/csrc/jit/script/script_type_parser.cpp
index b32d7c2..6b8c2d5 100644
--- a/torch/csrc/jit/script/script_type_parser.cpp
+++ b/torch/csrc/jit/script/script_type_parser.cpp
@@ -19,6 +19,7 @@
// parsing serialized methods that use implicit converions to Scalar
{"number", NumberType::get()},
{"None", NoneType::get()},
+ {"Any", AnyType::get()},
};
return map;
}
@@ -269,21 +270,20 @@
auto decl_arg = *it;
TypePtr type;
- c10::optional<int32_t> N;
+ c10::optional<int32_t> N = c10::nullopt;
bool is_inferred_type = false;
if (!decl_arg.type().present()) {
// If this param doesn't have a type, default to "tensor"
is_inferred_type = true;
type = TensorType::get();
- N = c10::nullopt;
} else {
// BroadcastList list can only appear at the argument level
- if (auto maybe_broad_list = parseBroadcastList(decl_arg.type().get())) {
+ Expr type_expr = decl_arg.type().get();
+ if (auto maybe_broad_list = parseBroadcastList(type_expr)) {
type = maybe_broad_list->first;
N = maybe_broad_list->second;
} else {
type = parseTypeFromExpr(decl_arg.type().get());
- N = c10::nullopt;
}
}
c10::optional<IValue> default_value = c10::nullopt;
diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp
index f5e425e..0aa64a3 100644
--- a/torch/csrc/jit/script/sugared_value.cpp
+++ b/torch/csrc/jit/script/sugared_value.cpp
@@ -212,6 +212,7 @@
<< "Classes that recursively contain instances of themselves"
<< " are not yet supported";
}
+
classType->addAttribute(field, newValue->type());
expectedType = newValue->type();
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 7fecd1a..53b3f70 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -146,7 +146,7 @@
}
return it->second;
}
- std::ostringstream oss;
+ std::ostringstream oss;
if (var.isFuture()) {
oss << "Tried to trace Future or Object that the tracer was not aware of.";
} else {
@@ -285,7 +285,7 @@
Value* self_value,
const script::Module& self) {
Graph& g = *self_value->owningGraph();
-
+
state->setValue(self.module_object(), self_value);
for (script::Slot s : self.get_slots()) {
@@ -304,7 +304,7 @@
// varied on subsequent invocations of the trace. Any other variables
// will be treated as constants.
std::pair<std::shared_ptr<TracingState>, Stack> enter(
- TypedStack inputs,
+ Stack inputs,
script::Module* self) {
if (isTracing()) {
AT_ERROR("Tracing can't be nested");
@@ -321,12 +321,10 @@
}
size_t i = 0;
- auto input_types = inputs.types()->elements();
- for (IValue& input : inputs.stack()) {
- input = addInput(state,
- input, input_types[i++], state->graph->addInput());
+ for (IValue& input : inputs) {
+ input = addInput(state, input, input.type(), state->graph->addInput());
}
- return std::make_pair(state, inputs.stack());
+ return std::make_pair(state, inputs);
}
// Exit a trace, treating 'outputs' as the outputs of the trace. These
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index a8685cc..30bb865 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -205,31 +205,8 @@
TORCH_API Value* getValueTrace(const IValue& var);
-struct TypedStack : public std::pair<Stack, TupleTypePtr>
-{
- using pair::pair;
-
- // NB: The inherited default constructor gives nullptr for |type|,
- // so we provide a saner one.
- TypedStack()
- : pair({}, TupleType::create({}))
- {}
-
- Stack& stack() {
- return this->first;
- }
- TupleTypePtr& types() {
- return this->second;
- }
- size_t size() {
- auto s = stack().size();
- AT_ASSERT(s == types()->elements().size());
- return s;
- }
-};
-
TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(
- TypedStack inputs,
+ Stack inputs,
script::Module* self = nullptr);
TORCH_API void exit(const Stack& outputs);
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index e5c91f4..120dfbb 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -5,9 +5,10 @@
import torch
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
- is_optional, _qualified_name
+ is_optional, _qualified_name, Any
from torch._C import TensorType, TupleType, FloatType, IntType, \
- ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType
+ ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType
+
from textwrap import dedent
from torch._six import builtins
from torch._utils_internal import get_source_lines_and_file
@@ -28,15 +29,6 @@
raise RuntimeError("Module {} has no member called {}".format(self.name, name))
-_eval_env = {
- 'torch': Module('torch', {'Tensor': torch.Tensor}),
- 'Tensor': torch.Tensor,
- 'typing': Module('typing', {'Tuple': Tuple}),
- 'Tuple': Tuple,
- 'List': List,
- 'Dict': Dict,
- 'Optional': Optional,
-}
class EvalEnv(object):
env = {
'torch': Module('torch', {'Tensor': torch.Tensor}),
@@ -244,6 +236,8 @@
return StringType.get()
elif ann is bool:
return BoolType.get()
+ elif ann is Any:
+ return AnyType.get()
elif hasattr(ann, "__torch_script_class__"):
return ClassType(_qualified_name(ann))
elif hasattr(ann, "__torch_script_interface__"):
@@ -258,6 +252,7 @@
__all__ = [
+ 'Any',
'List',
'BroadcastingList1',
'BroadcastingList2',
@@ -274,6 +269,7 @@
'ListType',
'StringType',
'DictType',
+ 'AnyType',
'Module',
# TODO: Consider not exporting these during wildcard import (reserve
# that for the types; for idiomatic typing code.)