blob: d22f8864a66842cc7a2d190f235ec329c0e9448c [file] [log] [blame]
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/script/parser.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/script/builtin_functions.h"
#include "torch/csrc/jit/hooks_for_testing.h"
#include "torch/csrc/jit/constants.h"
#include "c10/util/Optional.h"
#include <algorithm>
#include <cctype>
#include <climits>
#include <set>
namespace torch {
namespace jit {
namespace script {
// Shared handle through which sugared values are passed around the compiler.
using SugaredValuePtr = std::shared_ptr<SugaredValue>;

// Name-keyed tables of values and methods.
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
using FunctionTable = std::unordered_map<std::string, Method&>;

// Name-keyed tables of constant attributes (single-valued and list-valued).
using AttributeMap = std::unordered_map<std::string, Const>;
using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
// Sugared value standing for Python's `None`; it reports its kind as "None".
struct NoneValue : SugaredValue {
  NoneValue() = default;
  std::string kind() const override { return "None"; }
};
// Placeholder bound to the builtin name `getattr`. It carries no behavior of
// its own; getattr expressions are recognized by matching against this value
// and given special handling.
struct GetAttrValue : SugaredValue {
  std::string kind() const override { return "getattr"; }
};
// Sugared value for the builtin `print(...)`. Calling it emits a prim::Print
// node (with no outputs) over the positional arguments and evaluates to None.
// Keyword arguments are rejected with an ErrorReport.
struct PrintValue : public SugaredValue {
  std::string kind() const override {
    return "print";
  }
  std::shared_ptr<SugaredValue> call(
      SourceRange loc,
      Method & m,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    auto& g = *m.graph();
    if (!attributes.empty())
      throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
    std::vector<Value*> lowered_inputs = toValues(g, inputs);
    // temporary hack to allow print statements to work in python 2, where
    // print(a, b) is treated as a (a, b) tuple input: when the only argument
    // was produced by a TupleConstruct, print the tuple's elements instead.
    // The elements are copied in one step rather than inserted one at a time
    // at the front of the vector.
    if (lowered_inputs.size() == 1 &&
        lowered_inputs[0]->node()->kind() == prim::TupleConstruct) {
      const auto tuple_elems = lowered_inputs[0]->node()->inputs();
      lowered_inputs = std::vector<Value*>(tuple_elems.begin(), tuple_elems.end());
    }
    g.insertNode(g.create(prim::Print, lowered_inputs, 0)
                     ->setSourceLocation(std::make_shared<SourceRange>(loc)));
    return std::make_shared<NoneValue>();
  }
};
// Inserts the explicit conversion node that turns `value` into type `dst`,
// tagged with source location `loc`, and returns the converted value.
// Supported conversions:
//   number -> tensor, tensor -> number, tensor -> bool,
//   float -> int, int -> float, string -> float.
// Any other pair of types raises an ErrorReport at `loc`.
static Value* typeCast(const SourceRange& loc, Value* value, TypePtr dst) {
  Graph& g = *value->owningGraph();
  const TypePtr src = value->type();
  const bool src_is_tensor = src->isSubtypeOf(DynamicType::get());
  Node* conversion = nullptr;
  if (dst->isSubtypeOf(DynamicType::get()) && src->isSubtypeOf(NumberType::get())) {
    conversion = g.createNumToTensor(value);
  } else if (dst->isSubtypeOf(NumberType::get()) && src_is_tensor) {
    conversion = g.createTensorToNum(dst, value);
  } else if (dst->isSubtypeOf(BoolType::get()) && src_is_tensor) {
    conversion = g.createTensorToBool(value);
  } else if (dst->isSubtypeOf(IntType::get()) && src->isSubtypeOf(FloatType::get())) {
    conversion = g.createFloatToInt(value);
  } else if (dst->isSubtypeOf(FloatType::get()) && src->isSubtypeOf(IntType::get())) {
    conversion = g.createIntToFloat(value);
  } else if (dst->isSubtypeOf(FloatType::get()) && src->isSubtypeOf(StringType::get())) {
    conversion = g.createStringToFloat(value);
  } else {
    throw ErrorReport(loc) << "Cannot cast type '" << src->str() << "' to type '"
                           << dst->str() << "'.";
  }
  Node* inserted = g.insertNode(conversion);
  inserted->setSourceLocation(std::make_shared<SourceRange>(loc));
  return inserted->output();
}
// expressions like int(x)
// Sugared value for the cast builtins (int, float, bool, _to_tensor). Calling
// it with exactly one positional argument converts that argument to `type`.
// A conversion node is inserted via typeCast only when the argument is not
// already a subtype of `type`. Keyword arguments, or any argument count other
// than one, raise an ErrorReport.
struct CastValue : public SugaredValue {
  // explicit: a TypePtr must not silently convert into a cast builtin.
  explicit CastValue(TypePtr type)
    : type(std::move(type)) {}
  std::string kind() const override {
    std::stringstream ss;
    ss << "<" << type->str() << " cast primitive>";
    return ss.str();
  }
  std::shared_ptr<SugaredValue> call(
      SourceRange loc,
      Method & m,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    if (!attributes.empty())
      throw ErrorReport(loc) << "casts do not accept any keyword arguments";
    if (inputs.size() != 1)
      throw ErrorReport(loc) << "expected a single argument for cast";
    auto values = toValues(*m.graph(), inputs);
    Value* input = values.at(0);
    if (!input->type()->isSubtypeOf(type)) {
      input = typeCast(loc, input, type);
    }
    return std::make_shared<SimpleValue>(input);
  }
 private:
  // Target type of the cast.
  TypePtr type;
};
// We consider "_N", where N is a (possibly empty) run of digits, to be a
// non-meaningful name and do not record it as a unique name. This allows
// python printing to export and import more consistently named graphs.
// Returns false for "", "_" and "_<digits>"; true for every other name.
static bool meaningfulName(const std::string& name) {
  if (name.empty())
    return false;
  if (name[0] != '_')
    return true;
  // std::isdigit has undefined behavior for negative char values (e.g. the
  // bytes of a UTF-8 identifier), so widen through unsigned char first.
  return std::any_of(name.begin() + 1, name.end(), [](char c) {
    return !std::isdigit(static_cast<unsigned char>(c));
  });
}
// Auxiliary data structure for desugaring variable binding into our always
// explicitly scoped language as we descend down
// nested control structures in the frontend (which themselves don't introduce
// scopes)
//
// The algorithm is roughly as follows:
// 1) While emitting a block within a control operator, add inputs and outputs
// from the block for each value referenced (both "reads" and "writes").
// This sets the value up as a candidate loop carried dependency.
// 2) When we reach the end of the block, examine all the values in the current
// scope's value map. If the name also resides in an outer scope with a
// different Value*, this is a true loop-carried dependency. If not, this
// value was not assigned to. Replace all references to the block input
// with the Value* pointed to in the tightest enclosing scope. Then delete
// that block input and output.
// 3) When we emit the actual control operator, take all of the loop-carried
// dependency values as inputs and return them as outputs from the control
// op
//
// Note that an alternative implementation could add the loop-carried dep
// inputs and outputs only when we see a value that is mutated. This, however,
// requires replacing all references to that value *within the current
// block* with a new input. That is to say: we need to traverse the
// predecessor nodes and replace inputs that reference that value with the
// newly-created input. This could be made less expensive with a change to
// the IR API, but for now we choose to pessimistically create inputs and
// delete unnecessary ones later with replaceAllUsesWith().
// One lexical frame of the scope chain used while emitting a block. Frames
// are linked through `next`, innermost first; the outermost frame has no
// `next`.
struct Environment {
  Environment(Method & method, Resolver resolver, Block* b, std::shared_ptr<Environment> next = nullptr)
  : method(method), resolver(std::move(resolver)), b(b), next(std::move(next)) {}
  Method & method;
  // Consulted by getSugaredVar for names that are neither in scope nor one of
  // the builtin globals.
  Resolver resolver;
  // Names of this frame's candidate loop-carried dependencies, kept sorted.
  // Entry i corresponds to block input/output i + 1 (see deleteExtraInputs).
  std::vector<std::string> captured_inputs;
  // Deferred type-error messages by variable name; only the outermost
  // frame's map is used.
  std::unordered_map<std::string, std::string> error_messages;
  // The block this frame emits into.
  Block* b;
  // The immediately enclosing frame, or nullptr for the outermost one.
  std::shared_ptr<Environment> next;

  // Records `msg` as the type error for `name` in the outermost environment.
  // If the variable is used after an error has been set, getSugaredVar
  // reports this more informative message instead of "undefined value".
  void setVariableTypeError(const std::string& name, const std::string &msg) {
    outermostFrame()->error_messages[name] = msg;
  }

  // Returns the type error recorded for `name`, if any.
  c10::optional<std::string> findVariableTypeError(const std::string& name) {
    Environment* root = outermostFrame();
    auto msg = root->error_messages.find(name);
    if (msg != root->error_messages.end()) {
      return msg->second;
    }
    return c10::nullopt;
  }

  // Looks `name` up in this frame only; nullptr if absent.
  SugaredValuePtr findInThisFrame(const std::string& name) {
    auto it = value_table.find(name);
    if (it != value_table.end()) {
      return it->second;
    }
    return nullptr;
  }

  // Looks `name` up in the enclosing frames (not this one); nullptr if absent.
  SugaredValuePtr findInParentFrame(const std::string& name) {
    return next ? next->findInAnyFrame(name) : nullptr;
  }

  // Looks `name` up from this frame outwards; nullptr if it is bound nowhere.
  SugaredValuePtr findInAnyFrame(const std::string& name) {
    for (auto runner = this; runner; runner = runner->next.get()) {
      if (auto r = runner->findInThisFrame(name)) {
        return r;
      }
    }
    return nullptr;
  }

  // Returns the graph Value bound to `name` in this frame. Uses
  // unordered_map::at, so it throws if `name` is not bound here.
  Value* getValueInThisFrame(const SourceRange& loc, const std::string& name) {
    return value_table.at(name)->asValue(loc, method);
  }

  // Adds a block input with orig's type for the loop-carried dependency
  // `name` and binds `name` to that input in this frame.
  SugaredValuePtr createCapturedInput(Value* orig, const std::string& name) {
    // insert the captured input alphabetically in the capture list.
    // this ensures consistency of the order of loop-carried dependencies
    // even when the use in the loop is in a different order
    const auto pos = std::lower_bound(captured_inputs.begin(), captured_inputs.end(), name);
    const size_t insert_pos = static_cast<size_t>(pos - captured_inputs.begin());
    captured_inputs.insert(pos, name);
    // Create the input. Block input 0 is not a loop-carried dependency, so
    // captured input i lives at block input i + 1.
    const size_t loop_carried_block_inputs_offset = 1;
    Value* new_input = b->insertInput(loop_carried_block_inputs_offset + insert_pos)
        ->setType(orig->type());
    // Associate this name with this value
    auto sv = std::make_shared<SimpleValue>(new_input);
    value_table[name] = sv;
    return sv;
  }

  // Returns the binding of `ident` visible from this frame. When the binding
  // comes from an enclosing frame and this frame's block is owned by a
  // prim::Loop, a captured input is created for it first (recursively through
  // enclosing loops). Takes `ident` by value so that it never aliases
  // storage that is mutated below.
  SugaredValuePtr createCapturedInputIfNeeded(const SourceRange& loc, std::string ident) {
    auto in_frame = findInThisFrame(ident);
    if (in_frame) {
      return in_frame;
    }
    // recursively handles the case where parent blocks are also loops
    auto from_parent = next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;
    // recursively create the captured input if it is the loop block
    if (from_parent && getBlockOwningKind() == prim::Loop) {
      if (Value* simple_val = asSimple(from_parent))
        from_parent = createCapturedInput(simple_val, ident);
    }
    return from_parent;
  }

  Block* block() {
    return b;
  }

  // Kind of the node owning this frame's block, or an empty Symbol when the
  // block has no owning node.
  Symbol getBlockOwningKind() {
    Symbol owning_kind = Symbol();
    if (b->owningNode()) {
      owning_kind = b->owningNode()->kind();
    }
    return owning_kind;
  }

  void setVar(const SourceRange& loc, const std::string& name, Value* value) {
    setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
  }

  // The underlying Value* if `value` is a SimpleValue, otherwise nullptr.
  static Value* asSimple(const SugaredValuePtr& value) {
    if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
      return sv->getValue();
    }
    return nullptr;
  }

  // Binds `name` to `value` in this frame. If `name` already exists in an
  // enclosing frame, both the old and the new binding must be first-class
  // (simple) values. The new type must be a subtype of the old binding's
  // unshaped type. Otherwise an ErrorReport is thrown.
  void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) {
    Value* as_simple_value = asSimple(value);
    if (as_simple_value && !as_simple_value->hasUniqueName() &&
        meaningfulName(name) &&
        // note: if the value wasn't defined in this block, we might be giving a name
        // only used inside this block to a value outside of this. this is not
        // normally helpful for debugging and causes import/export jitter.
        as_simple_value->node()->owningBlock() == block()) {
      as_simple_value->setUniqueName(name);
    }
    // prevent re-assignment involving any sugared values
    // any reassignment like:
    // a = ...
    // while ...
    //   a = ..
    // requires 'a' to be first-class in the graph since its value depends on
    // control flow
    if (auto parent = findInParentFrame(name)) {
      if (!as_simple_value) {
        throw ErrorReport(loc) << "Cannot re-assign '" << name << "' to a value of type " << value->kind() <<
        " because " << name << " is not a first-class value. Only reassignments to first-class values are allowed";
      }
      Value* simple_parent = asSimple(parent);
      if (!simple_parent) {
        // Report the kind of the existing, non-first-class binding. `value`
        // is known to be a simple value at this point, so its kind says
        // nothing useful.
        throw ErrorReport(loc) << "Cannot re-assign '" << name << "' because it has type " << parent->kind() <<
        " and " << name << " is not a first-class value. Only reassignments to first-class values are allowed";
      }
      if (!as_simple_value->type()->isSubtypeOf(
              unshapedType(simple_parent->type()))) {
        std::stringstream errMsg;
        errMsg << "variable '" << name << "' previously has type "
               << simple_parent->type()->str()
               << " but is now being assigned to a value of type "
               << as_simple_value->type()->str();
        // Special-cased error msg if we're trying to assign to a tensor list.
        if (simple_parent->type()->kind() == TypeKind::ListType &&
            as_simple_value->type()->kind() == TypeKind::ListType) {
          errMsg << "\n. (Note: empty lists are constructed as Tensor[]; "
                 << "if you want an empty list of a different type, "
                 << "use `torch.jit.annotate(List[T], [])`, "
                 << "where `T` is the type of elements in the list)";
        }
        throw ErrorReport(loc) << errMsg.str();
      }
    }
    if (as_simple_value)
      createCapturedInputIfNeeded(loc, name);
    value_table[name] = std::move(value);
  }

  // Looks up `ident`. When `required` is false, a missing name yields nullptr
  // instead of an ErrorReport. `required` was previously ignored here.
  SugaredValuePtr getSugaredVar(const Ident& ident, bool required=true) {
    return getSugaredVar(ident.name(), ident.range(), required);
  }

  Value* getVar(const Ident& ident) {
    return getSugaredVar(ident)->asValue(ident.range(), method);
  }

  // Resolves `ident`, trying in order: the scope chain (creating loop
  // captured inputs as needed), the builtin globals below, then `resolver`.
  // If nothing is found and `required` is true, throws an ErrorReport. A
  // recorded variable type error is preferred over "undefined value".
  SugaredValuePtr getSugaredVar(const std::string& ident, SourceRange range, bool required=true) {
    auto retval = createCapturedInputIfNeeded(range, ident);
    if (!retval) {
      static std::unordered_map<std::string, SugaredValuePtr> globals = {
        {"print", std::make_shared<PrintValue>()},
        {"float", std::make_shared<CastValue>(FloatType::get())},
        {"int", std::make_shared<CastValue>(IntType::get())},
        {"bool", std::make_shared<CastValue>(BoolType::get())},
        {"getattr", std::make_shared<GetAttrValue>()},
        // todo(zach): remove when we can correctly export torch.full via ONNX
        // or we have implicit conversion that can convert numbers to tensors
        {"_to_tensor", std::make_shared<CastValue>(DynamicType::get()) },
      };
      auto it = globals.find(ident);
      if (it != globals.end())
        retval = it->second;
    }
    if (!retval) {
      retval = resolver(ident, method, range);
    }
    if (!retval && required) {
      // check if this value was not emitted in an if statement because of a
      // type mismatch. if it was, then we print a more informative error msg
      if (auto msg = findVariableTypeError(ident)) {
        throw ErrorReport(range) << *msg << "and was used here";
      }
      throw ErrorReport(range) << "undefined value " << ident;
    }
    return retval;
  }

  Value* getVar(const std::string& ident, SourceRange range) {
    return getSugaredVar(ident, range)->asValue(range, method);
  }

  // Given that after emitting statements in a block, we've added block inputs
  // for all value references and assignments, delete inputs for which there was
  // no assignment, only references.
  void deleteExtraInputs(const SourceRange& loc) {
    // note: skip i == 0, it is the loop trip count for inputs
    // and the loop condition for outputs.
    // captured_inputs is indexed by i - 1 since it only contains loop
    // carried dependencies
    //          inputs: loop_counter, lcd0, lcd1, ...
    //         outputs: loop_condition, lcd0, lcd1, ...
    // captured_inputs: lcd0, lcd1, ...
    JIT_ASSERT(b->inputs().size() == b->outputs().size());
    JIT_ASSERT(b->inputs().size() == captured_inputs.size() + 1);
    for (size_t i = b->inputs().size() - 1; i > 0; i--) {
      // nothing changed along this loop
      if (b->inputs()[i] == b->outputs()[i]) {
        auto name = captured_inputs[i - 1];
        Value* orig = findInParentFrame(name)->asValue(loc, method);
        b->inputs()[i]->replaceAllUsesWith(orig);
        b->eraseInput(i);
        b->eraseOutput(i);
        captured_inputs.erase(captured_inputs.begin() + i - 1);
      }
    }
  }

  // Names bound directly in this frame (unordered).
  std::vector<std::string> definedVariables() {
    std::vector<std::string> result;
    result.reserve(value_table.size());
    for (const auto& kv : value_table) {
      result.push_back(kv.first);
    }
    return result;
  }

 private:
  // Follows the `next` chain to the outermost (root) frame.
  Environment* outermostFrame() {
    Environment* runner = this;
    while (runner->next) {
      runner = runner->next.get();
    }
    return runner;
  }

  ValueTable value_table;
};
// Returns the only element of `values` as-is. Otherwise (zero or several
// values) inserts a node built by createTuple and returns its output.
Value* packOutputs(Graph& g, at::ArrayRef<Value*> values) {
  if (values.size() != 1) {
    Node* tuple = g.insertNode(g.createTuple(values));
    return tuple->output();
  }
  return values[0];
}
// Returns the element values of the tuple-valued `v`. If `v` came straight
// from a prim::TupleConstruct, that node's inputs are reused; otherwise a
// tuple-unpack node is inserted and its outputs are returned.
at::ArrayRef<Value*> createTupleUnpack(Value* v) {
  Node* producer = v->node();
  // small peephole optimization to ensure IntList attributes can still turn
  // into constants e.g. in x.expand([3, 4])
  if (producer->kind() != prim::TupleConstruct) {
    Graph& g = *v->owningGraph();
    Node* unpack = g.insertNode(g.createTupleUnpack(v));
    return unpack->outputs();
  }
  return producer->inputs();
}
// For Optional[T] returns T; any other type is returned unchanged.
inline TypePtr unwrapOptional(TypePtr opt_type) {
  auto as_optional = opt_type->cast<OptionalType>();
  if (!as_optional) {
    return opt_type;
  }
  return as_optional->getElementType();
}
// True when `value` is a scalar int or float passed for a fixed-size list
// argument (e.g. int[2] or float[3], possibly Optional) whose element type is
// exactly the scalar's type. tryMatchArgument repeats such scalars N times.
static inline bool isIntOrFloatUsedAsList(
    const Value* value,
    const Argument& arg) {
  const auto v_type = value->type();
  const bool is_scalar = v_type == FloatType::get() || v_type == IntType::get();
  if (!is_scalar)
    return false;
  const auto list_type = unwrapOptional(arg.type())->cast<ListType>();
  if (!list_type)
    return false;
  return list_type->getElementType() == v_type && arg.N();
}
// True if a value of `type` may stand in for the list type `list_type_`. That
// holds when `type` is already a subtype of it, or when `type` is a tuple
// whose every element (vacuously, for an empty tuple) is a subtype of the
// list's element type. Always false if `list_type_` is not a list type.
inline bool convertibleToList(TypePtr type, TypePtr list_type_) {
  auto list_type = list_type_->cast<ListType>();
  if (!list_type)
    return false;
  if (type->isSubtypeOf(list_type_))
    return true;
  auto tuple = type->cast<TupleType>();
  if (!tuple)
    return false;
  const auto elem_type = list_type->getElementType();
  for (const auto& t : tuple->elements()) {
    if (!t->isSubtypeOf(elem_type))
      return false;
  }
  return true;
}
// Applies implicit conversions to `value` in an attempt to turn it into
// `concrete_type`. Success is not guaranteed: callers must check that the
// result's type isSubtypeOf(concrete_type). Conversions attempted, in order:
//  1. a tuple whose elements fit a (possibly Optional) list type becomes a
//     list of that element type;
//  2. None, when a non-None type is expected, becomes a None generator, an
//     undefined tensor (for subtypes of Optional[Tensor]) or a typed None
//     (for other Optional[T]);
//  3. a tensor becomes a number via ImplicitTensorToNum when a number is
//     expected and `convert_tensors_to_nums` is set.
Value* tryConvertToType(
    const SourceRange& loc,
    Graph& graph,
    TypePtr concrete_type,
    Value* value,
    bool convert_tensors_to_nums) {
  const TypePtr target = unwrapOptional(concrete_type);
  const bool is_tuple = value->type()->kind() == TypeKind::TupleType;
  if (is_tuple && convertibleToList(value->type(), target)) {
    auto unpacked = createTupleUnpack(value);
    auto elem_type = target->expect<ListType>()->getElementType();
    Node* list = graph.createList(elem_type, unpacked);
    value = graph.insertNode(list)->output();
  }

  const bool value_is_none = value->type()->isSubtypeOf(NoneType::get());
  if (value_is_none && !concrete_type->isSubtypeOf(NoneType::get())) {
    if (concrete_type->isSubtypeOf(GeneratorType::get())) {
      value = graph.insertNode(graph.createNoneGenerator())->output();
    } else if (concrete_type->isSubtypeOf(OptionalType::ofTensor())) {
      // None passed to an Optional[Tensor] formal becomes an undefined tensor
      value = graph.insertNode(graph.createUndefined())->output();
    } else if (auto optional_type = concrete_type->cast<OptionalType>()) {
      value = graph.insertNode(graph.createNone(optional_type->getElementType()))->output();
    }
  }

  const bool wants_number = concrete_type->isSubtypeOf(NumberType::get());
  if (convert_tensors_to_nums && wants_number &&
      value->type()->isSubtypeOf(DynamicType::get())) {
    Node* to_num = graph.insertNode(graph.createImplicitTensorToNum(concrete_type, value));
    to_num->setSourceLocation(std::make_shared<SourceRange>(loc));
    value = to_num->output();
  }
  return value;
}
// Tries to bind `named_value` to the schema argument `arg`, resolving type
// variables in `type_env` and applying implicit conversions. Returns the value
// to pass on success. On failure writes a diagnostic to err() and returns
// nullptr.
Value* tryMatchArgument(
    const Argument& arg,
    Graph& graph,
    const SourceRange& loc,
    const NamedValue& named_value,
    std::function<std::ostream&()> err,
    bool convert_tensors_to_nums,
    TypeEnv & type_env) {
  Value* value = named_value.value(graph);
  // A scalar int/float given for a fixed-size list such as int[N] is repeated
  // N times to build that list.
  if (isIntOrFloatUsedAsList(value, arg)) {
    const std::vector<Value*> repeated(*arg.N(), value);
    Node* list = graph.createList(value->type(), repeated);
    value = graph.insertNode(list)->output();
  }

  const MatchTypeReturn matched =
      matchTypeVariables(arg.type(), value->type(), type_env);
  if (!matched.type) {
    err() << "could not match type " << value->type()->str() << " to "
          << arg.type()->str() << " in argument '" << arg.name()
          << "': " << matched.errMsg << "\n"
          << named_value.locOr(loc);
    return nullptr;
  }
  const TypePtr concrete_type = *matched.type;

  Value* converted = tryConvertToType(loc, graph, concrete_type, value, convert_tensors_to_nums);
  if (converted->type()->isSubtypeOf(concrete_type)) {
    return converted;
  }
  err() << "expected a value of type " << concrete_type->str() << " for argument '" << arg.name() << "' but found "
        << converted->type()->str() << "\n"
        << named_value.locOr(loc);
  return nullptr;
}
// Index of the first keyword argument called `name`, or nullopt if none is.
c10::optional<size_t> findInputWithName(
    const std::string& name,
    at::ArrayRef<NamedValue> kwargs) {
  size_t index = 0;
  for (const NamedValue& kwarg : kwargs) {
    if (kwarg.name() == name)
      return index;
    ++index;
  }
  return c10::nullopt;
}
// Matches every vararg against `elem_type` and packs the results into a new
// list node, returning the list. If any element fails to match,
// tryMatchArgument has already written the diagnostic and nullptr is returned.
Value* tryCreateList(
    TypePtr elem_type,
    Graph& graph,
    const SourceRange& loc,
    at::ArrayRef<NamedValue> varargs,
    std::function<std::ostream&()> err,
    bool convert_tensor_to_num,
    TypeEnv & type_env) {
  const Argument elem_arg("<varargs>", elem_type);
  std::vector<Value*> elements;
  elements.reserve(varargs.size());
  for (size_t i = 0; i < varargs.size(); ++i) {
    Value* matched = tryMatchArgument(
        elem_arg, graph, loc, varargs[i], err, convert_tensor_to_num, type_env);
    if (matched == nullptr)
      return nullptr;
    elements.push_back(matched);
  }
  Node* list = graph.createList(elem_type, elements);
  return graph.insertNode(list)->output();
}
// Returns a constant Value for `val`, reusing the one already recorded in
// `map` when present. New constants are inserted before the first node of the
// graph's top-level block and recorded in `map`.
template<class T>
static Value* materializeConstant(T val, Graph& graph,
    const SourceRange& r, std::unordered_map<T, Value*>& map) {
  auto cached = map.find(val);
  if (cached == map.end()) {
    WithInsertPoint guard(graph.block()->nodes().front());
    Value* fresh = graph.insertConstant(val, r);
    cached = map.emplace(val, fresh).first;
  }
  return cached->second;
}
// Attempts to match a call (optional `self`, positional `args`, keyword
// `kwargs`) against `schema`. On success returns the positional inputs to
// pass to the operator, together with its return types (type variables
// resolved). On failure appends a diagnostic to `failure_messages` and
// returns nullopt. When `convert_tensors_to_nums` is true, tensors may be
// implicitly converted to numbers for number-typed arguments.
c10::optional<MatchedSchema> tryMatchSchema(
    const FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    c10::optional<NamedValue> self,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    std::ostream& failure_messages,
    bool convert_tensors_to_nums) {
  auto err = [&]() -> std::ostream& {
    failure_messages << "\nfor operator " << schema << ":\n";
    return failure_messages;
  };
  TypeEnv type_env;
  std::vector<Value*> positional_inputs;
  std::vector<bool> used_kwarg(kwargs.size(), false);
  // number of positional `args` consumed so far; after the loop this tells us
  // whether all positional arguments were used
  size_t used_args = 0;
  for (size_t schema_i = 0; schema_i < schema.arguments().size(); ++schema_i) {
    const auto& arg = schema.arguments()[schema_i];
    c10::optional<NamedValue> v;
    if (arg.name() == "self" && self) {
      v = self;
      self = c10::nullopt;
    } else if (!arg.kwarg_only() && used_args < args.size()) {
      // allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
      if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
          !arg.N() && // it must not be a broadcasting list like int[3], otherwise
                      // a single int is a valid input
          (schema_i + 1 == schema.arguments().size() ||
           schema.arguments()[schema_i + 1]
               .kwarg_only())) { // must be the last position argument
        auto actual_type = args[used_args].value(graph)->type();
        if (actual_type->kind() != TypeKind::ListType &&
            !convertibleToList(
                actual_type,
                unwrapOptional(arg.type()))) { // and the actual should not be a list already
          auto elem_type = unwrapOptional(arg.type())->expect<ListType>()->getElementType();
          Value* list = tryCreateList(
              elem_type,
              graph,
              loc,
              at::ArrayRef<NamedValue>(args).slice(used_args),
              err,
              convert_tensors_to_nums,
              type_env);
          if (!list)
            return c10::nullopt;
          used_args = args.size();
          positional_inputs.push_back(list);
          continue;
        }
      }
      v = args[used_args];
      used_args++;
    } else if (auto idx = findInputWithName(arg.name(), kwargs)) {
      const NamedValue& nv = kwargs[*idx];
      if (used_kwarg[*idx]) {
        err() << "argument " << nv.name()
              << " specified twice in schema, submit a bug report!\n"
              << nv.locOr(loc);
        return c10::nullopt;
      }
      used_kwarg[*idx] = true;
      v = nv;
    } else if (arg.default_value()) {
      v = NamedValue(*arg.default_value());
    } else {
      err() << "argument " << schema.arguments()[schema_i].name()
            << " not provided.\n"
            << loc;
      return c10::nullopt;
    }
    Value* positional = tryMatchArgument(
        arg, graph, loc, *v, err, convert_tensors_to_nums, type_env);
    if (!positional)
      return c10::nullopt;
    positional_inputs.push_back(positional);
  }
  // check for unused self argument: a provided `self` that no schema argument
  // consumed means this schema does not fit the call. Reject it instead of
  // silently dropping `self`.
  if (self != c10::nullopt) {
    err() << "provided self argument not used in schema\n";
    return c10::nullopt;
  }

  if (schema.is_vararg()) {
    for (; used_args < args.size(); ++used_args) {
      positional_inputs.push_back(args[used_args].value(graph));
    }
  }

  // check for unused positional arguments
  if (used_args < args.size()) {
    err() << "expected at most " << used_args << " arguments "
          << "but found " << args.size() << " positional arguments.\n"
          << loc << "\n";
    return c10::nullopt;
  }
  // check for unused kwargs
  for (size_t i = 0; i < kwargs.size(); ++i) {
    const auto& nv = kwargs[i];
    if (!used_kwarg[i]) {
      if (!schema.argumentIndexWithName(nv.name())) {
        err() << "keyword argument " << nv.name() << " unknown\n";
      } else {
        err() << "keyword argument " << nv.name() << " specified twice\n";
      }
      return c10::nullopt;
    }
  }
  auto return_types = fmap(schema.returns(), [&](const Argument& r) {
    return evalTypeVariables(r.type(), type_env);
  });
  return MatchedSchema{std::move(positional_inputs), std::move(return_types)};
}
// Returns `str` with `prefix` inserted at the start of every line. A trailing
// newline does not start a new (empty) prefixed line, and an empty input
// yields an empty result.
static std::string prefixLine(const std::string& str, std::string prefix) {
  std::string out;
  size_t line_start = 0;
  while (line_start < str.size()) {
    const size_t newline = str.find('\n', line_start);
    const size_t line_end =
        (newline == std::string::npos) ? str.size() : newline + 1;
    out += prefix;
    out.append(str, line_start, line_end - line_start);
    line_start = line_end;
  }
  return out;
}
// Given a successful match between operator schema and symbol, emits a node
// of kind `name` taking the matched inputs, with one output per matched
// return type, and returns the outputs packed via packOutputs.
static Value* emitBuiltinNode(
    const MatchedSchema& matched_schema,
    const SourceRange& loc,
    Graph& graph,
    Symbol name) {
  Node* node = graph.insertNode(graph.create(name, matched_schema.inputs, 0));
  node->setSourceLocation(std::make_shared<SourceRange>(loc));
  for (const auto& ret_type : matched_schema.return_types) {
    node->addOutput()->setType(ret_type);
  }
  // getOperation is called only to assert that the created op has an
  // implementation; otherwise schema and dispatch are not in sync.
  getOperation(node);
  return packOutputs(graph, node->outputs());
}
// Search for operators matching the provided symbol name and input types.
// If one is found, emit a node to the graph for that operator.
// Registered operator variants are tried first, then script-defined builtin
// functions. The whole search runs once without and once with implicit
// tensor-to-number conversion. If nothing matches, returns nullptr when
// `required` is false and otherwise throws an ErrorReport.
Value* emitBuiltinCall(
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    c10::optional<NamedValue> self,
    at::ArrayRef<NamedValue> inputs,
    at::ArrayRef<NamedValue> attributes,
    // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
    // otherwise it will return nullptr if the builtin is not found.
    bool required) {
  const auto& variants = getAllOperatorsFor(name);
  const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
  std::stringstream failure_messages;
  //first we try to match the schema without any conversion
  //if no schema matches then insert ImplicitTensorToNum
  for (bool convert_tensors_to_nums : {false, true}) {
    // clear previous error messages
    failure_messages.str("");
    for (const std::shared_ptr<Operator>& op : variants) {
      const auto matched_schema = tryMatchSchema(
          op->schema(),
          loc,
          graph,
          self,
          inputs,
          attributes,
          failure_messages,
          convert_tensors_to_nums);
      if (matched_schema) {
        return emitBuiltinNode(*matched_schema, loc, graph, name);
      }
    }
    for (Method* method : builtin_functions) {
      if (auto result = try_emit_call_to(
              graph,
              loc,
              *method,
              self,
              inputs,
              attributes,
              failure_messages,
              nullptr,
              convert_tensors_to_nums)) {
        return packOutputs(graph, *result);
      }
    }
  }

  // none of the options worked
  if (!required) {
    return nullptr;
  }
  // The op is only "unknown" when there is neither a registered operator nor
  // a builtin function of this name. Otherwise the failures collected above
  // explain why the existing candidates did not match.
  if (variants.empty() && builtin_functions.empty()) {
    throw ErrorReport(loc) << "unknown builtin op";
  }
  throw ErrorReport(loc) << "arguments for call are not valid:\n"
                         << prefixLine(failure_messages.str(), "  ")
                         << "for call at";
}
// Returns `v` unchanged when its type is a subtype of int. Otherwise throws an
// ErrorReport at `range` naming the actual type.
static Value* ensureInt(const SourceRange& range, Value* v) {
  const bool is_int = v->type()->isSubtypeOf(IntType::get());
  if (is_int)
    return v;
  throw ErrorReport(range) << "expected a int but found a "
                           << v->type()->str();
}
// Calling a builtin emits the matching operator or builtin function with
// required = true, so a failed match throws. The result is wrapped as a
// SimpleValue.
std::shared_ptr<SugaredValue> BuiltinFunction::call(
    SourceRange loc,
    Method& m,
    at::ArrayRef<NamedValue> inputs,
    at::ArrayRef<NamedValue> attributes,
    size_t n_binders) {
  Graph& graph = *m.graph();
  Value* result = emitBuiltinCall(
      loc, graph, symbol, self, inputs, attributes, /*required=*/true);
  return std::make_shared<SimpleValue>(result);
}
// True for tensor types (subtypes of DynamicType) and subtypes of NumberType.
inline bool isSupportedListElementType(TypePtr type) {
  if (type->isSubtypeOf(DynamicType::get()))
    return true;
  return type->isSubtypeOf(NumberType::get());
}
// Forward declarations of the type-annotation helpers used by to_ir below;
// their definitions are not visible in this part of the file.
// Parses a type annotation expression into a TypePtr.
TypePtr parseTypeFromExpr(Expr expr);
// For a broadcasting-list annotation, returns (list type, N) — parseArgsFromDecl
// uses `.first` as the argument type and `.second` as its N. Returns nullopt
// otherwise. NOTE(review): exact accepted annotation forms not visible here.
c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(Expr expr);
struct to_ir {
// Compiles the function definition `def` into `method`. In order it:
//  - opens the top-level scope on the method's graph;
//  - binds the formal parameters;
//  - emits the body statements and the trailing `return`, if any;
//  - sets the method's schema;
//  - runs the LowerSimpleTuples and ConstantPooling passes over the graph.
// When `self` is non-null, `def` is a method. Its first parameter is bound to
// `self` instead of becoming a graph input.
to_ir(
Def def,
Resolver resolver_,
SugaredValuePtr self,
Method& method) // method being constructed
: method(method)
, graph(method.graph())
, def(def)
, resolver(std::move(resolver_))
, self(self)
, environment_stack(nullptr) {
JIT_ASSERT(resolver);
// The outermost frame emits into the graph's top-level block.
pushFrame(graph->block());
// Type annotations exclude explicitly typing the "self" parameter, so in the
// case that this is a method with self we expect one fewer parameter annotation
// than the number of parameters this Def takes.
if (self && def.decl().params().size() == 0) {
throw ErrorReport(def.decl().params().range()) << "methods must have a self argument";
}
auto schema = extractSchemaFromDef(def);
std::vector<Argument> arguments = emitFormalArguments(self, schema);
// body
// A trailing `return` statement is split off from the other statements and
// handled by emitReturn, which yields the schema's return Arguments (it is
// also called, with nullopt, when there is no trailing return).
auto stmts = def.statements();
auto stmts_begin = stmts.begin();
auto stmts_end = stmts.end();
c10::optional<Return> return_stmt;
if (stmts_begin != stmts_end && (*std::prev(stmts_end)).kind() == TK_RETURN) {
--stmts_end;
return_stmt = Return(*stmts_end);
}
emitStatements(stmts_begin, stmts_end);
std::vector<Argument> returns = emitReturn(return_stmt, schema);
method.setSchema({def.name().name(), std::move(arguments), std::move(returns)});
// remove any uses of tuples that we inserted that are not needed
LowerSimpleTuples(graph);
ConstantPooling(graph);
}
private:
// The method being compiled, and its graph (initialized from method.graph()).
Method& method;
std::shared_ptr<Graph> graph;
// The definition being compiled.
Def def;
// Passed to every Environment frame; resolves names not found in any scope
// or among the builtin globals.
Resolver resolver;
// Value bound to the first parameter when compiling a method; null otherwise.
SugaredValuePtr self;
// Maps from constant value to Value*. NOTE(review): presumably the caches
// handed to materializeConstant (their types match its `map` parameter); the
// uses are outside this view.
std::unordered_map<int64_t, Value*> integral_constants;
std::unordered_map<double, Value*> fp_constants;
// Singly-linked list of environments. The top element is the innermost scope;
// its member `next` points to the environment of the immediately enclosing scope.
std::shared_ptr<Environment> environment_stack;
// Opens a new innermost scope that emits into block `b` and is chained to the
// current innermost scope.
void pushFrame(Block * b) {
  auto frame = std::make_shared<Environment>(method, resolver, b, environment_stack);
  environment_stack = std::move(frame);
}
std::shared_ptr<Environment> popFrame() {
auto old_frame = environment_stack;
environment_stack = environment_stack->next;
return old_frame;
}
// Computes the values of the default-argument expressions `default_exprs`,
// whose annotations are `default_types`. Returns an empty vector when there
// are no defaults.
// Rather than handling defaults separately from ordinary expressions, this
// reuses the normal pipeline. It builds the AST of
//   def defaults() -> Tuple[<default_types>]:
//       return (<default_exprs>)
// compiles it into a scratch module, runs it, and returns the values it
// leaves on the stack.
std::vector<IValue> evaluateDefaults(const SourceRange& r, const std::vector<Expr>& default_types, const std::vector<Expr>& default_exprs) {
  std::vector<IValue> default_values;
  if (default_exprs.empty()) {
    return default_values;
  }
  auto return_annotation = Subscript::create(
      r,
      Var::create(r, Ident::create(r, "Tuple")),
      List<Expr>::create(r, default_types));
  auto decl = Decl::create(
      r, List<Param>::create(r, {}), Maybe<Expr>::create(r, return_annotation));
  auto returned_tuple = TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
  auto return_stmt = Return::create(r, List<Expr>::create(r, { returned_tuple }));
  auto defaults_def = Def::create(
      r,
      Ident::create(r, "defaults"),
      decl,
      List<Stmt>::create(r, {return_stmt}));
  auto scratch_module = std::make_shared<Module>();
  defineMethodsInModule(scratch_module, {defaults_def}, {resolver}, nullptr);
  scratch_module->get_method("defaults").run(default_values);
  return default_values;
}
// Builds the schema Arguments for the declared parameters of `decl`. The
// leading `self` parameter is skipped when compiling a method. Types come
// from the annotations; broadcasting-list annotations also set N. Defaults
// are evaluated up front, all at once, via evaluateDefaults.
std::vector<Argument> parseArgsFromDecl(Decl decl) {
  auto first = decl.params().begin();
  auto last = decl.params().end();
  if (self) {
    ++first;
  }

  // Collect the annotation and expression of every parameter with a default.
  std::vector<Expr> default_types;
  std::vector<Expr> default_exprs;
  for (auto it = first; it != last; ++it) {
    auto param = *it;
    auto maybe_default = param.defaultValue();
    if (maybe_default.present()) {
      default_types.emplace_back(param.type());
      default_exprs.emplace_back(maybe_default.get());
    }
  }
  auto default_values = evaluateDefaults(decl.range(), default_types, default_exprs);
  auto next_default = default_values.begin();

  std::vector<Argument> retval;
  for (auto it = first; it != last; ++it) {
    auto param = *it;
    TypePtr type;
    c10::optional<int32_t> N;
    // BroadcastingList annotations may only appear at the argument level.
    if (auto broadcast = handleBroadcastList(param.type())) {
      type = broadcast->first;
      N = broadcast->second;
    } else {
      type = parseTypeFromExpr(param.type());
    }
    c10::optional<IValue> default_value;
    if (param.defaultValue().present()) {
      default_value = *next_default;
      ++next_default;
    }
    retval.emplace_back(
        param.ident().name(),
        type,
        N,
        default_value,
        /*kwarg_only =*/false);
  }
  return retval;
}
// Builds the schema return Arguments from `decl`'s return annotation, which
// must be present. A Tuple annotation is flattened into one unnamed return per
// element; any other type gives a single unnamed return. Broadcasting lists
// are rejected with an ErrorReport.
std::vector<Argument> parseReturnsFromDecl(Decl decl) {
  JIT_ASSERT(decl.return_type().present());
  auto annotation = decl.return_type().get();
  if (handleBroadcastList(annotation))
    throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type";
  const TypePtr parsed_type = parseTypeFromExpr(annotation);

  std::vector<TypePtr> return_types;
  if (auto tuple_type = parsed_type->cast<TupleType>()) {
    // Flatten a single return type of type Tuple into its constituent types
    for (const auto& elem : tuple_type->elements()) {
      return_types.push_back(elem);
    }
  } else {
    return_types.push_back(parsed_type);
  }

  std::vector<Argument> retval;
  for (const TypePtr& t : return_types) {
    retval.emplace_back(
        "",
        t,
        /*N =*/c10::nullopt,
        /*default_value =*/c10::nullopt,
        /*kwarg_only =*/false);
  }
  return retval;
}
// Builds the FunctionSchema for `def` from its parameter and return
// annotations. Without a return annotation, the schema is marked varret and
// its return list is left empty.
FunctionSchema extractSchemaFromDef(const Def &def) {
  auto fn_name = def.name().name();
  std::vector<Argument> arguments = parseArgsFromDecl(def.decl());
  const bool has_return_annotation = def.decl().return_type().present();
  std::vector<Argument> return_args;
  if (has_return_annotation) {
    return_args = parseReturnsFromDecl(def.decl());
  }
  const bool is_varret = !has_return_annotation;
  return FunctionSchema(fn_name, arguments, return_args, false, is_varret);
}
// Declares the function's parameters in the graph and in the environment.
// If `self` is non-null, the first declared parameter name is bound to it as
// a sugared value and it is not counted against the schema. Every remaining
// parameter becomes a new graph input: it is named after the parameter (when
// meaningfulName() accepts it), bound in environment_stack, and typed from the
// matching entry of schema.arguments().
// Returns the schema Arguments consumed, in parameter order.
// Throws ErrorReport when the number of schema arguments differs from the
// number of (non-self) parameters.
std::vector<Argument> emitFormalArguments(SugaredValuePtr self, const FunctionSchema& schema) {
std::vector<Argument> arguments; // for schema
// inputs
auto it = def.decl().params().begin();
auto end = def.decl().params().end();
// NOTE(review): if `self` is set and params() is empty, size() - 1 wraps
// around; the mismatch check below would then fire before the JIT_ASSERT.
auto expected_annotation_size = self ? def.decl().params().size() - 1 : def.decl().params().size();
if (schema.arguments().size() != expected_annotation_size) {
throw ErrorReport(def.decl().params().range()) << "Number of type annotations for"
<< " function parameters (" << schema.arguments().size() << ")"
<< " does not match the number of parameters on the function ("
<< expected_annotation_size << ")!";
}
// Bind the receiver to the first parameter's name; it gets no graph input.
if(self) {
JIT_ASSERT(it != end);
environment_stack->setSugaredVar(def.range(), (*it).ident().name(), self);
++it;
}
// Index into schema.arguments(); advances once per non-self parameter.
size_t arg_annotation_idx = 0;
for(;it != end; ++it) {
auto& name = (*it).ident().name();
// Add the input to the graph
Value *new_input = graph->addInput();
if (meaningfulName(name)) {
new_input->setUniqueName(name);
}
environment_stack->setVar((*it).ident().range(), name, new_input);
// Record the type for the schema and set the Type on the Value*
arguments.push_back(schema.arguments().at(arg_annotation_idx++));
new_input->setType(arguments.back().type());
}
return arguments;
}
// Emits the graph outputs for the function's trailing return statement.
//
// If `return_stmt_` is present, its values are emitted. A single returned
// value of Tuple type is unpacked, so each element becomes its own graph
// output. When the schema is not varret:
//   - the number of values must equal the number of annotated returns;
//   - each value is passed through tryConvertToType and must then be a
//     subtype of its annotation.
// Returns one unnamed Argument per registered output, typed with the
// annotation, or DynamicType when the schema is varret.
//
// Throws ErrorReport on a count or type mismatch, or when the schema declares
// returns but the function has no return statement.
std::vector<Argument> emitReturn(c10::optional<Return> return_stmt_, const FunctionSchema& schema) {
  // outputs
  std::vector<Argument> returns;
  if (return_stmt_) {
    auto return_stmt = *return_stmt_;
    auto results = getValues(return_stmt.values(), true);
    // a single return value that is a tuple expands in place:
    // return a
    if (return_stmt.values().size() == 1 && results.size() == 1) {
      auto result = results.at(0);
      if (result->type()->cast<TupleType>()) {
        results = createTupleUnpack(result).vec();
      }
    }
    if (!schema.is_varret() && schema.returns().size() != results.size()) {
      throw ErrorReport(def.range()) << "Number of type annotations for function"
        << " return (" << schema.returns().size() << ") does not match"
        << " the number of returns from the function (" << results.size() << ")!";
    }
    auto range = return_stmt.range();
    size_t return_type_idx = 0;
    for (auto r : results) {
      TypePtr type = DynamicType::get();
      if (!schema.is_varret()) {
        type = schema.returns().at(return_type_idx).type();
        r = tryConvertToType(range, *graph, type, r, /*convert_tensors_to_nums=*/false);
        if (!r->type()->isSubtypeOf(type)) {
          throw ErrorReport(return_stmt.range()) << "Return value at position "
            << return_type_idx << " was annotated as having type " << type->str()
            << " but is actually of type " << r->type()->str();
        }
        return_type_idx++;
      }
      graph->registerOutput(r);
      returns.emplace_back("", type);
    }
  } else if (schema.returns().size() > 0) {
    // The schema has returns but the body has no return statement. Attach the
    // function's source range (as the count-mismatch error above does) so the
    // report points the user at the offending definition.
    throw ErrorReport(def.range()) << "Expected " << schema.returns().size()
      << " return value"
      << (schema.returns().size() > 1 ? "s" : "")
      << " but found no return statement";
  }
  return returns;
}
// Emits every statement of `statements`, in order, at the current insertion point.
void emitStatements(const List<Stmt>& statements) {
  emitStatements(statements.begin(), statements.end());
}
// Emits the statements in [begin, end) in order, dispatching on statement
// kind. A `return` statement encountered here is rejected with an error;
// unknown statement kinds are rejected as well.
void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
  for (auto cursor = begin; cursor != end; ++cursor) {
    auto stmt = *cursor;
    switch (stmt.kind()) {
      case TK_IF:
        emitIf(If(stmt));
        break;
      case TK_WHILE:
        emitWhile(While(stmt));
        break;
      case TK_FOR:
        emitFor(For(stmt));
        break;
      case TK_ASSIGN:
        emitAssignment(Assign(stmt));
        break;
      case TK_AUG_ASSIGN:
        emitAugAssignment(AugAssign(stmt));
        break;
      case TK_GLOBAL:
        // Every name listed in `global` becomes a fresh, named graph input.
        for (auto ident : Global(stmt).names()) {
          const auto& global_name = Ident(ident).name();
          environment_stack->setVar(ident.range(), global_name, graph->addInput(global_name));
        }
        break;
      case TK_EXPR_STMT:
        // Emitted with zero binders; the resulting value is discarded.
        emitSugaredExpr(ExprStmt(stmt).expr(), 0);
        break;
      case TK_RAISE:
        emitRaise(Raise(stmt).range());
        break;
      case TK_ASSERT:
        emitAssert(Assert(stmt));
        break;
      case TK_RETURN:
        throw ErrorReport(stmt) << "return statements can appear only at the end "
                                << "of the function body";
      case TK_PASS:
        // `pass` emits nothing.
        break;
      default:
        throw ErrorReport(stmt)
            << "Unrecognized statement kind " << kindToString(stmt.kind());
    }
  }
}
// Emits the statements of `branch` into block `b` inside a fresh environment
// frame. Returns that frame so the caller can inspect what the branch defined.
std::shared_ptr<Environment> emitSingleIfBranch(
    Block* b,
    const List<Stmt> branch) {
  pushFrame(b);
  WithInsertPoint insert_into_branch(b);
  emitStatements(branch);
  auto branch_env = popFrame();
  return branch_env;
}
// Creates a node of `kind` with `n_outputs` outputs in `graph`, tagged with
// source location `loc`. The node is returned uninserted; callers insert it.
Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
  auto location = std::make_shared<SourceRange>(loc);
  Node* fresh = graph->create(kind, n_outputs);
  return fresh->setSourceLocation(location);
}
// Lowers `a if cond else b` to a value-producing prim::If. Each arm is only
// emitted inside its own block.
Value* emitTernaryIf(const TernaryIf& expr) {
  Value* condition = emitCond(expr.cond());
  return emitIfExpr(
      expr.range(),
      condition,
      [&]() { return emitExpr(expr.true_expr()); },
      [&]() { return emitExpr(expr.false_expr()); });
}
// Lowers `a or b` / `a and b` with short-circuit semantics.
// `first_expr` is always evaluated; `second_expr` is emitted only inside the
// branch that needs it:
//   - for `or`, the false branch;
//   - for `and`, the true branch.
// The other branch just forwards the first value.
Value* emitShortCircuitIf(
    const SourceRange& loc,
    const TreeRef & first_expr,
    const TreeRef & second_expr,
    bool is_or) {
  Value* lhs_val = emitCond(Expr(first_expr));
  std::function<Value*()> reuse_lhs = [lhs_val] { return lhs_val; };
  std::function<Value*()> eval_rhs = [&] { return emitCond(Expr(second_expr)); };
  std::function<Value*()>& when_true = is_or ? reuse_lhs : eval_rhs;
  std::function<Value*()>& when_false = is_or ? eval_rhs : reuse_lhs;
  return emitIfExpr(loc, lhs_val, when_true, when_false);
}
// Emits a value-producing prim::If on `cond_value`.
// Each arm runs in its own environment frame and must yield exactly one
// value, which is registered as that block's output. After normalizing with
// unshapedType(), the two branch types must be equal, otherwise an
// ErrorReport is thrown. The node's single output carries that type.
Value* emitIfExpr(const SourceRange& range, Value * cond_value,
    std::function<Value*()> true_expr, std::function<Value*()> false_expr) {
  Node* if_node = graph->insertNode(create(prim::If, range, 0));
  if_node->addInput(cond_value);
  Block* then_block = if_node->addBlock();
  Block* else_block = if_node->addBlock();

  auto fill_branch = [this](Block* blk, const std::function<Value*()>& produce) {
    pushFrame(blk);
    WithInsertPoint guard(blk);
    Value* produced = produce();
    blk->registerOutput(produced);
    popFrame();
  };
  fill_branch(then_block, true_expr);
  fill_branch(else_block, false_expr);

  auto then_type = unshapedType(then_block->outputs().at(0)->type());
  auto else_type = unshapedType(else_block->outputs().at(0)->type());
  if (*then_type != *else_type) {
    throw ErrorReport(range)
        << "if-expression's true branch has type " << then_type->str()
        << " but false branch has type " << else_type->str();
  }
  return if_node->addOutput()->setType(then_type);
}
// Emits `cond` and requires its type to be a subtype of bool.
// Otherwise throws an ErrorReport; when the value is a tensor
// (DynamicType), the message suggests an explicit `bool()` cast.
Value* emitCond(Expr cond) {
  Value* result = emitExpr(cond);
  TypePtr result_type = result->type();
  if (result_type->isSubtypeOf(BoolType::get())) {
    return result;
  }
  ErrorReport error(cond);
  error << "expected a boolean expression for condition but found "
        << result_type->str();
  if (result_type->isSubtypeOf(DynamicType::get())) {
    error << ", to use a tensor in a boolean"
          << " expression, explicitly cast it with `bool()`";
  }
  throw error;
}
// Emits an `if` statement as a prim::If node with one block per branch.
// Each branch is emitted in its own environment frame (emitSingleIfBranch).
//
// A variable becomes an output of the If node when it is defined in one
// branch and visible (in any frame) from the other. Its type is the
// unification of the two branch types. When the types cannot be unified:
//   - if the variable already existed in an enclosing frame, the mismatch
//     is thrown immediately;
//   - otherwise the error is recorded via setVariableTypeError, and the
//     variable is not added as an output.
void emitIf(const If& stmt) {
Value* cond_value = emitCond(stmt.cond());
Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
n->addInput(cond_value);
auto* true_block = n->addBlock();
auto* false_block = n->addBlock();
// Emit both blocks once to get the union of all mutated values
auto save_true = emitSingleIfBranch(true_block, stmt.trueBranch());
auto save_false = emitSingleIfBranch(false_block, stmt.falseBranch());
// In python, every variable assigned in an if statement escapes
// the scope of the if statement (all variables are scoped to the function).
// Script is a subset of python: we consider variables to be in scope
// as long as there is a definition of the variable along all paths
// through the if statement
// ----
// if ...:
// a =
// else:
// ...
// ... = a # error, a is not defined along all paths
// ----
// if ...:
// a =
// else:
// a =
// ... = a # OK, a is defined along all paths
// ----
// a = ...
// if ...:
// a =
// ... = a # OK, a is defined along all paths
//ordered set, because we want deterministic graph output
std::set<std::string> mutated_variables;
// Defined in the true branch and reachable from the false branch.
for(auto & v : save_true->definedVariables()) {
if(save_false->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
// Defined in the false branch and reachable from the true branch.
for(auto & v : save_false->definedVariables()) {
if(save_true->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
// Register outputs in each block
for (const auto& x : mutated_variables) {
auto tv = save_true->getVar(x, stmt.range());
auto fv = save_false->getVar(x, stmt.range());
auto unified = unifyTypes(tv->type(), fv->type());
// attempt to unify the types. we allow variables to be set to different types
// in each branch as long as that variable is not already in scope,
// or if that variable does not get used later. here, we save the error
// so that the error message will be more informative in the case that is
// used later. When a is accessed in (a + 1), the error will get printed
// if cond:
// a = 1
// else:
// a = tensor
// b = a + 1
//
if (!unified) {
ErrorReport error(stmt);
error << "Type mismatch: " << x << " is set to type " << tv->type()->str() << " in the true branch"
<< " and type " << fv->type()->str() << " in the false branch";
if (save_true->findInParentFrame(x) || save_false->findInParentFrame(x)) {
throw error;
} else {
// error gets saved in the lowest environment because all
// variables are scoped to the function. doesn't matter if this is accessed
// through save_true or save_false
save_true->setVariableTypeError(x, error.what());
continue;
}
}
// Each block yields its own definition; the node output gets the unified type.
true_block->registerOutput(tv);
false_block->registerOutput(fv);
environment_stack->setVar(stmt.range(), x, n->addOutput()->setType(*unified));
}
}
// *********************** Loop Operators ************************************
// Emits a loop operator conforming to the semantics specified at
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#experimental-loop
// TODO: implement scan_outputs
// the format of the Loop instruction is:
// loop_carried_outputs* = Loop(max_trip_count, start_condition,
// loop_carried_inputs*)
// block0(loop_counter, loop_carried_block*) {
// <body>
// -> (continue_condition,
// loop_carried_block_outputs*)
// }
// all loop_carried_... lists are the same length and represent the value of
// loop-carried variables whose definitions are updated as the loop executes
// in a way that ensures single static assignment.
//
// Parameters:
//   max_trip_count - if present, emitted and passed through ensureInt();
//                    otherwise std::numeric_limits<int64_t>::max() is used.
//   cond           - if present, emitted once for the start condition and
//                    again at the end of the body for the continue
//                    condition; otherwise the constant `true` is used for both.
//   body           - statements emitted into the loop block in a new frame.
//   itr_ident      - if present, bound to the loop counter (block input 0)
//                    inside the body.
void emitLoopCommon(
SourceRange range,
c10::optional<Expr> max_trip_count,
c10::optional<Expr> cond,
const List<Stmt>& body,
c10::optional<Ident> itr_ident) {
Node* n = graph->insertNode(create(prim::Loop, range, 0));
Value *max_trip_count_val, *cond_val;
{
// Trip count and start condition are emitted with the insertion point set
// to `n` (presumably just before the Loop node — confirm WithInsertPoint).
WithInsertPoint guard(n);
if (max_trip_count) {
max_trip_count_val = ensureInt(
max_trip_count->range(), emitExpr(max_trip_count.value()));
} else {
max_trip_count_val =
materializeConstant(std::numeric_limits<int64_t>::max(), *graph, range, integral_constants);
}
if (cond) {
cond_val = emitCond(cond.value());
} else {
cond_val = graph->insertConstant(true, range);
}
}
n->addInput(max_trip_count_val);
n->addInput(cond_val);
auto* body_block = n->addBlock();
Value* trip_count = body_block->addInput()->setType(IntType::get()); // Iteration num
{
pushFrame(body_block);
if (itr_ident) {
environment_stack->setVar(itr_ident->range(), itr_ident->name(), trip_count);
}
WithInsertPoint guard(body_block);
emitStatements(body);
// Also emit the conditional
if (cond) {
Value* body_cond_value = emitCond(cond.value());
body_block->registerOutput(body_cond_value);
} else {
Value* cond_value_dummy = graph->insertConstant(true, range);
body_block->registerOutput(cond_value_dummy);
}
auto body_frame = popFrame();
auto outer_frame = environment_stack;
// Add block outputs to correspond to each captured input
// some of these will be removed.
for (const auto& x : body_frame->captured_inputs) {
auto fv = body_frame->getValueInThisFrame(range, x);
body_block->registerOutput(fv);
}
// Remove inputs for values that did not mutate within the
// block
body_frame->deleteExtraInputs(range);
// register node inputs/outputs for the true loop carried deps,
// and rebind each such variable in the outer frame to the Loop's output.
for(size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
auto x = body_frame->captured_inputs[i];
n->addInput(outer_frame->getVar(x, range));
// body_block->inputs(): loop_counter, lcd0, lcd1, ...
// captured_inputs: lcd0, lcd1, ...
auto typ = body_block->inputs()[i + 1]->type();
outer_frame->setVar(range, x, n->addOutput()->setType(typ));
}
}
}
// Lowers `for target in range(n)` to a counted loop whose counter is bound
// to `target`. Only the single-argument form of range() is supported.
void emitForRange(SourceRange range, const Ident& target, const List<Expr>& args, const List<Stmt>& body) {
  // TODO: start, stop, step loop
  const auto num_args = args.size();
  if (num_args == 1) {
    emitLoopCommon(range, {args[0]}, {}, body, target);
    return;
  }
  throw ErrorReport(range)
      << "range() expects 1 argument but got " << num_args;
}
void emitFor(const For& stmt) {
// For now, we only support range loops. e.g. for i in range(3): ...
auto targets = stmt.targets();
auto itrs = stmt.itrs();
auto body = stmt.body();
if (stmt.itrs().size() != 1) {
throw ErrorReport(stmt)
<< "List of iterables is not supported currently.";
}
if (targets.size() != 1) {
throw ErrorReport(stmt) << "Iteration variable unpacking is not supported";
}
if (targets[0].kind() != TK_VAR) {
throw ErrorReport(targets[0]) << "unexpected expression in variable initialization of for loop";
}
auto target = Var(targets[0]).name();
// match range(<expr>) style loops
// itrs must consist of a single Apply node
if (itrs[0].kind() == TK_APPLY) {
Apply range_iterator = Apply(itrs[0]);
if (range_iterator.callee().kind() == TK_VAR) {
Var var = Var(range_iterator.callee());
if (var.name().name() == "range") {
return emitForRange(stmt.range(), target, range_iterator.inputs(), body);
}
}
}
// it isn't a range(<expr>) loop, treat it as a sugared value that maybe can be
// unrolled
auto sv = emitSugaredExpr(itrs[0], 1);
auto instances = sv->asTuple(stmt.range(), method);
const std::string& target_name = target.name();
pushFrame(environment_stack->block());
for(auto inst : instances) {
environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
emitStatements(body);
}
for (const auto & n : environment_stack->definedVariables()) {
if (environment_stack->findInParentFrame(n)) {
environment_stack->next->setVar(stmt.range(), n, environment_stack->getVar(n, stmt.range()));
}
}
popFrame();
}
void emitWhile(const While& stmt) {
auto cond = stmt.cond();
emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
}
// Currently we do not support assigning exceptions to variables,
// a = Exception("hi")
// raise a
//
// We ignore the expression following raise
//
// NYI: add exception logic to control-flow nodes
// if True:
// a = 1
// else
// raise Exception("Hi")
// print(a)
void emitRaise(const SourceRange& loc) {
const std::string exception = "Exception";
auto string_input = insertConstant(*graph, exception, loc);
graph->insert(prim::RaiseException, {string_input}, {}, loc);
}
void emitAssert(const Assert& stmt) {
Value* cond_value = emitCond(stmt.test());
Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
n->addInput(cond_value);
/* true_block =*/n->addBlock();
auto* false_block = n->addBlock();
//if assert test is false throw exception
pushFrame(false_block);
WithInsertPoint guard(false_block);
emitRaise(stmt.range());
popFrame();
}
// Validates the `lhs` expressions of a tuple assignment statement:
//
// 1) Every lhs Expr is a Var, a Subscript, or a Starred node
// 2) There is at most one Starred node in the lhs Exprs
// 3) A Starred node can only appear when there is another non-Starred lhs Expr.
//    Concretely this means that `*abc = func()` is illegal. Unpacking all
//    outputs into a tuple is covered by `abc = func()`.
//
// Returns true iff the lhs contains a (necessarily single) Starred node.
// Throws ErrorReport when any rule above is violated.
bool calcNumStarredUnpack(const List<Expr>& lhs, const SourceRange& r) {
  size_t num_normal_assign = 0;
  size_t num_starred = 0;
  for (const auto& assignee : lhs) {
    if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT) {
      num_normal_assign++;
    } else if (assignee.kind() == TK_STARRED) {
      num_starred++;
    } else {
      throw ErrorReport(assignee) << "lhs of assignment must be a variable, "
                                  << "subscript, or starred expression.";
    }
  }
  if (num_starred > 1) {
    throw ErrorReport(r)
        << "Only one starred expression is allowed on the lhs.";
  }
  if (num_starred > 0 && num_normal_assign == 0) {
    throw ErrorReport(r) << "A Starred expression may only appear on the "
                         << "lhs within the presence of another non-starred"
                         << " expression.";
  }
  // Return the boolean explicitly rather than relying on an implicit
  // size_t -> bool conversion of the count.
  return num_starred > 0;
}
// Returns the builtin op implementing the operator of an augmented
// assignment. When `isTensor` is set, the ATen in-place variant is returned
// (e.g. aten::add_); otherwise the out-of-place op is returned (e.g.
// aten::add), which is also what list-element augmentation uses.
// Only +, -, *, / are supported; anything else throws an ErrorReport.
Symbol getAugOp(const AugAssign& stmt, bool isTensor) {
  const auto op = stmt.aug_op();
  if (op == '+') {
    return isTensor ? aten::add_ : aten::add;
  }
  if (op == '-') {
    return isTensor ? aten::sub_ : aten::sub;
  }
  if (op == '/') {
    return isTensor ? aten::div_ : aten::div;
  }
  if (op == '*') {
    return isTensor ? aten::mul_ : aten::mul;
  }
  throw ErrorReport(stmt) << "Unknown augmented assignment: "
                          << kindToString(op);
}
// Emits an augmented assignment (`+=`, `-=`, ...) by dispatching on the kind
// of its left-hand side: a variable, an attribute select (`a.b`), or a
// subscript (`a[i]`). Any other lhs raises an ErrorReport.
void emitAugAssignment(const AugAssign& stmt) {
  const auto lhs_kind = stmt.lhs().kind();
  if (lhs_kind == TK_VAR) {
    emitAugAssignmentToVar(stmt);
  } else if (lhs_kind == '.') {
    emitAugAssignmentToSelectVar(stmt);
  } else if (lhs_kind == TK_SUBSCRIPT) {
    emitAugAssignmentToSubscript(stmt);
  } else {
    throw ErrorReport(stmt.lhs())
        << "unexpected expression on "
        << "left-hand side of augmented assignment.";
  }
}
// This will be called when there is a class param or module buffer
// mutation which make the LHS of the expr be a select expression
//
// Example like:
// class A(Module):
// def __init__():
// self.register_buffer("running_var", torch.zeros(1))
//
// def forward():
// self.num_batches += 1
//
// In this case we will only consider the scenario that the module
// buffer type is a tensor, and we emit the corresponding tensor
// in place op, and throw error for other unsupported types
void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
const auto lhs = Select(stmt.lhs());
const auto lhsSugaredVar = environment_stack->getSugaredVar(Var(lhs.value()).name());
const auto lhsValue = lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())->asValue(lhs.range(), method);
if (lhsValue->type()->isSubtypeOf(DynamicType::get())) {
// for module parameter/buffer assignment, only consider tensor types,
// emit the corresponding in-place op
const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
self,
{rhs},
{},
/*required=*/true);
} else {
throw ErrorReport(stmt.lhs())
<< "left-hand side of augmented assignment to module "
<< "parameters/buffers can only be tensor types";
}
}
void emitAugAssignmentToVar(const AugAssign& stmt) {
const auto lhs = Var(stmt.lhs());
const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
->asValue(lhs.range(), method);
if (lhsValue->type()->isSubtypeOf(DynamicType::get())) {
// for tensors, emit the corresponding in-place op
const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
const auto output = emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
self,
{rhs},
{},
/*required=*/true);
environment_stack->setVar(lhs.range(), lhs.name().name(), output);
} else {
// for primitive types, desugar into a simple assignment
// e.g. foo += 1 becomes foo.2 = foo + 1
Ident lhs = Var(stmt.lhs()).name();
Expr expr = BinOp::create(
stmt.range(),
stmt.aug_op(),
Var::create(lhs.range(), lhs),
stmt.rhs());
environment_stack->setVar(lhs.range(), lhs.name(), emitExpr(expr));
}
}
void emitAugAssignmentToSubscript(const AugAssign& stmt) {
// Process the base list value
const auto lhs = Subscript(stmt.lhs());
const auto sliceable = emitExpr(lhs.value());
if (sliceable->type()->isSubtypeOf(DynamicType::get())) {
// If it's a tensor, just fully evaluate the subscript operation and emit
// an in-place assignment
std::vector<Value*> tensorIndices;
Value* sliced;
std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
lhs.range(), sliceable, lhs.subscript_exprs());
const auto slicedArg = NamedValue(stmt.lhs().range(), "self", sliced);
const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
if (tensorIndices.size() == 0) {
// Common case: we only tried to index with int and slices. Emit the
// correct augmented assignment op to the sliced value
emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
slicedArg,
{rhs},
{},
/*required=*/true);
} else {
// Special case: we tried to do "advanced indexing". Lower this expr
// into `index` and `index_put_` ops
const auto indices = graph->insertNode(
graph->createList(DynamicType::get(), tensorIndices))->output();
const auto indexed =
graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
const auto augmented = emitBuiltinCall(
stmt.range(),
*method.graph(),
getAugOp(stmt, /*isTensor=*/true),
indexed,
{rhs},
{},
/*required=*/true);
graph->insert(
aten::index_put_,
{slicedArg, indices, augmented},
{},
stmt.range());
}
} else {
// Otherwise, it should be a list. Lower this expression into:
// list.set_item(get_item(idx).add_(value))
// similar to how Python handles things.
const auto listType = sliceable->type()->cast<ListType>();
JIT_ASSERT(listType != nullptr);
bool isTensorList =
listType->getElementType()->isSubtypeOf(DynamicType::get());
// Get the idx to augment
const auto subscriptExprs = lhs.subscript_exprs();
if (subscriptExprs.size() != 1) {
throw ErrorReport(subscriptExprs)
<< "Sliced expression not yet supported for"
<< " subscripted list augmented assignment. "
<< "File a bug if you want this.";
}
const auto idxValue = emitExpr(subscriptExprs[0]);
const auto listArg = NamedValue(lhs.value().range(), "list", sliceable);
const auto idxArg = NamedValue(subscriptExprs.range(), "idx", idxValue);
const auto valueArg =
NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs()));
const auto getItem =
graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
const auto augmentedItem = graph->insert(
getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range());
graph->insert(
aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range());
}
}
// Emit mutating assignments like `foo[0] = bar`
void emitSubscriptAssign(
const SourceRange& stmtRange,
const Subscript& lhs,
const Expr& rhs) {
emitSubscriptAssign(
stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
}
void emitSubscriptAssign(
const SourceRange& stmtRange,
const Subscript& lhs,
const NamedValue& rhs) {
// First check the base value.
auto sliceable = emitExpr(lhs.value());
// If it's a tensor, copy the RHS data into it
if (sliceable->type()->isSubtypeOf(DynamicType::get())) {
std::vector<Value*> tensorIndices;
Value* sliced;
// Handle multi-dimensional slicing: first emit int/slice indexing
// TODO: the Python equivalent code has special-cased copy_to
// broadcasting to match NumPy semantics (see PR#4853). We can't
// replicate that without knowing the size of the Tensor; so really that
// code should be moved into the aten function
std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
lhs.range(), sliceable, lhs.subscript_exprs());
const auto slicedArg = NamedValue(lhs.range(), sliced);
if (tensorIndices.size() == 0) {
// Common case: we only tried to index with int and slices. Copy the
// RHS into the resulting tensor.
graph->insert(aten::copy_, {slicedArg, rhs}, {}, stmtRange);
} else {
// Special case: we tried to do "advanced indexing" with a tensor.
// Dispatch to `aten::index_put_`.
const auto indices = graph->insertNode(
graph->createList(DynamicType::get(), tensorIndices))->output();
graph->insert(
aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange);
}
// Otherwise, this is a list. Dispatch to aten::_set_item to both select and
// assign
} else {
const auto subscript = lhs.subscript_exprs();
if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
throw ErrorReport(subscript)
<< "Sliced expression not yet supported for"
<< " subscripted list assignment. "
<< "File a bug if you want this.";
}
std::vector<NamedValue> args;
args.emplace_back(lhs.value().range(), "list", sliceable);
args.emplace_back(
lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0]));
args.push_back(rhs);
graph->insert(aten::_set_item, args, {}, stmtRange);
}
}
// Emits a tuple-unpacking assignment such as `a, b[0], *c = rhs`.
// `rhs` is evaluated once as a sugared value and expanded with asTuple().
// - Without a starred target, the element count must equal the number of
//   targets.
// - With one starred target (validated by calcNumStarredUnpack), there must
//   be at least as many elements as non-starred targets. The starred
//   variable receives a new tuple of the surplus elements.
// Count mismatches and non-variable starred targets raise an ErrorReport.
void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
  size_t n_binders = tl.inputs().size();
  bool starred_unpack = calcNumStarredUnpack(tl.inputs(), tl.range());
  if (starred_unpack)
    n_binders--;
  auto output = emitSugaredExpr(rhs, n_binders);
  auto outputs = output->asTuple(
      rhs.range(),
      method,
      starred_unpack ? c10::nullopt : c10::optional<size_t>{n_binders});
  if (outputs.size() < n_binders) {
    throw ErrorReport(tl)
        << "need " << (starred_unpack ? "at least " : "")
        << n_binders << " values to unpack but found only "
        << outputs.size();
  }
  if (outputs.size() > n_binders && !starred_unpack) {
    throw ErrorReport(tl)
        << "too many values to unpack: need " << n_binders << " but found "
        << outputs.size();
  }
  // Index of the next element of `outputs` to bind. Kept unsigned so it is
  // compared and advanced consistently with outputs.size() and n_matched
  // (previously a signed int mixed with size_t arithmetic).
  size_t i = 0;
  for (auto assignee : tl.inputs()) {
    switch (assignee.kind()) {
      case TK_SUBSCRIPT:
        emitSubscriptAssign(
            rhs.range(),
            Subscript(assignee),
            NamedValue(
                rhs.range(), outputs.at(i)->asValue(rhs.range(), method)));
        i++;
        break;
      case TK_VAR:
        environment_stack->setSugaredVar(assignee.range(), Var(assignee).name().name(), outputs.at(i));
        i++;
        break;
      case TK_STARRED: {
        auto var = Starred(assignee).expr();
        if (var.kind() != TK_VAR) {
          throw ErrorReport(var) << "Cannot pack a tuple into a non-variable.";
        }
        // The starred target absorbs every element beyond the plain targets.
        size_t n_matched = outputs.size() - n_binders;
        ArrayRef<std::shared_ptr<SugaredValue>> outputs_ref = outputs;
        auto values = fmap(outputs_ref.slice(i, n_matched), [&](const std::shared_ptr<SugaredValue>& v) {
          return v->asValue(assignee.range(), method);
        });
        auto tup = graph->insertNode(graph->createTuple(values))->output();
        environment_stack->setVar(
            var.range(), Var(var).name().name(), tup);
        i += n_matched;
      } break;
      default:
        throw ErrorReport(assignee) << "unexpected expression on the left-hand side";
    }
  }
}
void emitAssignment(const Assign& stmt) {
switch(stmt.lhs().kind()) {
case TK_VAR: {
auto v = Var(stmt.lhs());
environment_stack->setSugaredVar(
v.range(), v.name().name(), emitSugaredExpr(stmt.rhs(), 1));
} break;
case TK_TUPLE_LITERAL:
emitTupleAssign(TupleLiteral(stmt.lhs()), stmt.rhs());
break;
case TK_SUBSCRIPT:
emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), stmt.rhs());
break;
default:
throw ErrorReport(stmt.lhs()) << "unexpected expression on left-hand side of assignment.";
}
}
// Maps a parser token kind for an operator to the builtin symbol that
// implements it. `ninputs` is not consulted. Throws std::runtime_error for
// an unrecognized kind.
NodeKind getNodeKind(int kind, int ninputs) {
  switch (kind) {
    // arithmetic
    case '+': return aten::add;
    case '-': return aten::sub;
    case TK_UNARY_MINUS: return aten::neg;
    case '*': return aten::mul;
    case '/': return aten::div;
    case TK_FLOOR_DIV: return aten::floordiv;
    case '%': return aten::remainder;
    case TK_POW: return aten::pow;
    case '@': return aten::matmul;
    // comparison / identity
    case TK_EQ: return aten::eq;
    case TK_NE: return aten::ne;
    case '<': return aten::lt;
    case '>': return aten::gt;
    case TK_LE: return aten::le;
    case TK_GE: return aten::ge;
    case TK_IS: return aten::__is__;
    case TK_ISNOT: return aten::__isnot__;
    // logical and bitwise (keyword and symbol forms share an op)
    case TK_AND:
    case '&':
      return aten::__and__;
    case TK_OR:
    case '|':
      return aten::__or__;
    case '^': return aten::__xor__;
    case TK_NOT: return aten::__not__;
    // unpacking
    case TK_STARRED: return prim::Starred;
    default:
      throw std::runtime_error("unknown kind " + std::to_string(kind));
  }
}
// Emits each tree as a NamedValue carrying the tree's source range.
// When `maybe_unpack` is set, a starred expression `*x` is expanded in place
// into one NamedValue per element of `x` (via asTuple()).
std::vector<NamedValue> getNamedValues(
    TreeList trees,
    bool maybe_unpack) {
  std::vector<NamedValue> named;
  for (const auto& tree : trees) {
    const bool expand = maybe_unpack && tree->kind() == TK_STARRED;
    if (!expand) {
      named.emplace_back(tree->range(), emitExpr(Expr(tree)));
      continue;
    }
    auto starred = Starred(tree);
    auto elements = emitSugaredExpr(starred.expr(), 1)->asTuple(starred.range(), method);
    for (const auto& element : elements) {
      named.emplace_back(
          tree->range(), element->asValue(starred.range(), method));
    }
  }
  return named;
}
// List<Expr> overload: emits the list's subtrees as NamedValues, expanding
// starred elements when `maybe_unpack` is set.
std::vector<NamedValue> getNamedValues(
    List<Expr> trees,
    bool maybe_unpack) {
  const auto& subtrees = trees.tree()->trees();
  return getNamedValues(subtrees, maybe_unpack);
}
// Emits `trees` as NamedValues (optionally expanding starred elements) and
// returns just the underlying Value*s.
std::vector<Value*> getValues(
    TreeList trees,
    bool maybe_unpack) {
  auto named = getNamedValues(trees, maybe_unpack);
  return toValues(*graph, named);
}
// List<Expr> overload: emits the list's subtrees and returns the bare
// Value*s, expanding starred elements when `maybe_unpack` is set.
std::vector<Value*> getValues(
    List<Expr> trees,
    bool maybe_unpack) {
  const auto& subtrees = trees.tree()->trees();
  return getValues(subtrees, maybe_unpack);
}
std::vector<NamedValue> emitAttributes(const List<Attribute> attributes) {
return fmap(attributes, [&](const Attribute& attr) {
return NamedValue(attr.range(), attr.name().name(), emitExpr(attr.value()));
});
}
// Emits a call expression `callee(args...)`. Special forms are recognized by
// the sugared value the callee evaluates to:
//   - fork(fn, args...): emitted via emitForkExpr.
//   - torch.jit.annotate(Type, expr): expr is emitted with Type as a hint,
//     converted, and must then be a subtype of Type.
//   - getattr(obj, "name"): resolved statically; name must be a string literal.
// Any other callee receives the emitted positional args (starred args
// expanded) and keyword args through SugaredValue::call.
std::shared_ptr<SugaredValue> emitApplyExpr(Apply &apply, size_t n_binders) {
  auto sv = emitSugaredExpr(apply.callee(), 1);
  auto loc = apply.callee().range();
  if (auto fork_value = dynamic_cast<ForkValue*>(sv.get())) {
    auto& trees = apply.inputs().tree()->trees();
    if (trees.size() < 1) {
      throw ErrorReport(loc) << "Expected at least one argument to fork()";
    }
    auto forked = emitSugaredExpr(Expr(trees[0]), 1);
    TreeList sliced_trees(trees.begin() + 1, trees.end());
    auto inputs = getNamedValues(sliced_trees, true);
    auto attributes = emitAttributes(apply.attributes());
    return emitForkExpr(loc, forked, inputs, attributes);
  } else if (auto annotate_value = dynamic_cast<AnnotateValue*>(sv.get())) {
    // These messages previously named the builtin "attribute" (a copy-paste
    // slip); this branch handles annotate().
    if (apply.inputs().size() != 2) {
      throw ErrorReport(loc)
          << "expected exactly two arguments to annotate but found "
          << apply.inputs().size();
    }
    if (apply.attributes().size() > 0) {
      throw ErrorReport(loc) << "annotate takes no keyword arguments";
    }
    TypePtr type = parseTypeFromExpr(apply.inputs()[0]);
    Value* expr = tryConvertToType(
        apply.range(),
        *graph,
        type,
        emitExpr(apply.inputs()[1], type),
        /*convert_tensors_to_nums=*/true);
    if (!expr->type()->isSubtypeOf(type)) {
      throw ErrorReport(apply.inputs())
          << "expected an expression of type " << type->python_str()
          << " but found " << expr->type()->python_str();
    }
    return std::make_shared<SimpleValue>(expr);
  } else if (auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
    if (apply.attributes().size() > 0) {
      throw ErrorReport(loc) << "getattr takes no keyword arguments";
    }
    if (apply.inputs().size() != 2) {
      throw ErrorReport(loc) << "getattr expects 2 inputs";
    }
    auto obj = emitSugaredExpr(apply.inputs()[0], 1);
    auto selector = apply.inputs()[1];
    if (selector.kind() != TK_STRINGLITERAL) {
      throw ErrorReport(loc) << "getattr's second argument must be a string literal";
    }
    const std::string& name = StringLiteral(selector).text();
    return obj->attr(apply.range(), method, name);
  } else {
    auto inputs = getNamedValues(apply.inputs(), true);
    auto attributes = emitAttributes(apply.attributes());
    return sv->call(loc, method, inputs, attributes, n_binders);
  }
}
// Emits `tree` with a single binder and collapses the resulting sugared
// value to one Value*. `type_hint` is forwarded to emitSugaredExpr.
Value* emitExpr(Expr tree, TypePtr type_hint = nullptr) {
  auto sugared = emitSugaredExpr(tree, 1, type_hint);
  return sugared->asValue(tree.range(), method);
}
// Returns the comparison that holds when the operands are swapped
// (a < b  <=>  b > a). Only lt/le/gt/ge are supported; any other kind
// throws std::runtime_error.
NodeKind reverseComparision(NodeKind kind) {
  // lt <-> gt
  if (kind == aten::lt || kind == aten::gt) {
    return kind == aten::lt ? aten::gt : aten::lt;
  }
  // le <-> ge
  if (kind == aten::le || kind == aten::ge) {
    return kind == aten::le ? aten::ge : aten::le;
  }
  throw std::runtime_error("reverseComparision: unsupported NodeKind. File a bug");
}
// Emits any expression that can produce a SugaredValue: variables, attribute
// selects, and calls. Every other expression kind is emitted by
// emitSimpleExpr and wrapped in a SimpleValue.
//
// type_hint is set if there is a type that this value is expected to be, e.g.
//   a : List[int] = []
//   a = torch.jit.annotate(List[int], [])
// The caller is responsible for checking that the result matches type_hint;
// emitSugaredExpr is free to ignore it (here it is only forwarded to
// emitSimpleExpr).
std::shared_ptr<SugaredValue> emitSugaredExpr(Expr tree, size_t n_binders, TypePtr type_hint=nullptr) {
  const auto kind = tree.kind();
  if (kind == TK_VAR) {
    return environment_stack->getSugaredVar(Var(tree).name());
  }
  if (kind == '.') {
    auto select = Select(tree);
    auto base = emitSugaredExpr(select.value(), 1);
    return base->attr(select.range(), method, select.selector().name());
  }
  if (kind == TK_APPLY) {
    auto apply = Apply(tree);
    return emitApplyExpr(apply, n_binders);
  }
  return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
}
// Emits unary minus as aten::neg. When the operand is a compile-time
// constant, the operator is run immediately and its result is inserted as a
// constant, which is returned instead of the neg node's output.
Value * emitNegate(const TreeRef& tree) {
  auto operands = getNamedValues(tree->trees(), /*maybe_unpack=*/false);
  Value* negated = emitBuiltinCall(
      tree->range(),
      *method.graph(),
      aten::neg,
      c10::nullopt,
      operands,
      {},
      /*required=*/true);

  Node* neg_node = negated->node();
  auto constant_operand = toIValue(neg_node->input());
  if (!constant_operand) {
    // Not a constant: keep the runtime aten::neg.
    return negated;
  }
  // Constant-fold: run the operator on the constant input.
  Stack stack{*constant_operand};
  getOperation(neg_node)(stack);
  JIT_ASSERT(stack.size() == 1);
  return graph->insertConstant(stack[0], tree->range());
}
// Emits a prim::fork node for a call to `forked` with the given inputs and
// attributes, and returns the fork's single output as a SimpleValue. That
// output has type Future[T], where T is the type of the call's result.
//
// The call is first emitted into a temporary block owned by the fork node.
// That block is then cloned into a standalone Graph, which is stored as the
// node's attr::Subgraph, and the temporary block is erased.
std::shared_ptr<SugaredValue> emitForkExpr(
SourceRange loc,
const std::shared_ptr<SugaredValue> &forked,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes) {
// Build the fork node without inputs; inputs are added below as captured
// values are discovered.
auto fork_node = method.graph()->insertNode(method.graph()->create(prim::fork, 1))
->setSourceLocation(std::make_shared<SourceRange>(loc));
auto body_block = fork_node->addBlock();
// Build a template of the graph to be executed by emitting the call into the
// fork node's block (the guard redirects insertion for this scope only).
Value *node_output;
{
WithInsertPoint guard(body_block);
auto fn_sugared_output = forked->call(loc, method, inputs, attributes, 1);
auto fn_simple_output = fn_sugared_output->asValue(loc, method);
body_block->registerOutput(fn_simple_output);
// The fork itself produces a Future wrapping the call's result type.
node_output = fork_node->output()->setType(FutureType::create(fn_simple_output->type()));
}
// Fork a new graph from its original owning graph.
auto forked_graph = std::make_shared<Graph>();
// Make sure we capture everything in the new graph.
// The uncaptured values will be added to the fork signature.
// NOTE(review): `env` is presumably invoked by cloneFrom only for values the
// block references but does not define; each such value becomes a new input
// of forked_graph and, in the same order, an input of fork_node.
std::unordered_map<Value*, Value*> uncaptures_map;
auto env = [&](Value* v) -> Value* {
if (!uncaptures_map.count(v)) {
// Capture values for both graphs
uncaptures_map[v] = forked_graph->addInput()->copyMetadata(v);
fork_node->addInput(v);
}
return uncaptures_map[v];
};
forked_graph->block()->cloneFrom(body_block, env);
// Separate the subgraph and clean up the original one: the fork node keeps
// the cloned graph as an attribute and drops its temporary block.
fork_node->g_(attr::Subgraph, forked_graph);
fork_node->eraseBlock(0);
return std::make_shared<SimpleValue>(node_output);
}
// Emits an expression that always yields exactly one graph Value*:
// operators, literals, subscripts, ternary ifs, and list/tuple displays.
// Expressions that may yield a SugaredValue are handled by emitSugaredExpr.
//
// type_hint, when set, is the type the result is expected to have. Here it is
// only consulted for list literals, to choose the element type (e.g. an empty
// `[]` annotated as List[int]). Callers remain responsible for checking the
// result against it.
//
// Throws ErrorReport for unsupported expression kinds, and for list literals
// whose elements are not all of the chosen element type.
Value* emitSimpleExpr(
    const TreeRef& tree,
    TypePtr type_hint = nullptr) {
  switch (tree->kind()) {
    case '@':
    case TK_POW:
    case TK_IS:
    case TK_ISNOT:
    case TK_NOT:
    case TK_NE:
    case TK_EQ:
    case '<':
    case '>':
    case TK_LE:
    case TK_GE:
    case '*':
    case '/':
    case '+':
    case '-':
    case '%':
    case '&':
    case '|':
    case '^':
    case TK_FLOOR_DIV: {
      // Operators lower directly to the builtin chosen by getNodeKind for this
      // token and operand count.
      const auto& inputs = tree->trees();
      auto kind = getNodeKind(tree->kind(), inputs.size());
      auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
      return emitBuiltinCall(
          tree->range(),
          *method.graph(),
          kind,
          c10::nullopt,
          named_values,
          {},
          /*required=*/true);
    }
    case TK_UNARY_MINUS: {
      return emitNegate(tree);
    }
    case TK_AND:
    case TK_OR: {
      // `and`/`or` go through emitShortCircuitIf rather than a builtin call;
      // the last argument selects `or` semantics.
      const auto& inputs = tree->trees();
      return emitShortCircuitIf(
          tree->range(),
          inputs[0],
          inputs[1],
          tree->kind() == TK_OR);
    }
    case TK_STARRED: {
      throw ErrorReport(tree) << "Unexpected starred expansion. File a bug report.";
    }
    case TK_CONST: {
      return emitConst(Const(tree));
    }
    case TK_TRUE: {
      return graph->insertConstant(true, tree->range());
    }
    case TK_FALSE: {
      return graph->insertConstant(false, tree->range());
    }
    case TK_NONE: {
      return graph->insertConstant(IValue(), tree->range());
    }
    case TK_SUBSCRIPT: {
      return emitSubscript(Subscript(tree));
    }
    case TK_IF_EXPR: {
      return emitTernaryIf(TernaryIf(tree));
    }
    case TK_STRINGLITERAL: {
      return emitStringLiteral(StringLiteral(tree));
    }
    case TK_LIST_LITERAL: {
      auto ll = ListLiteral(tree);
      auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
      // Determine the element type of the list:
      //  - if we have a type hint of List[T], use T
      //  - if the list is non-empty, use type_of(list[0])
      //  - otherwise assume it is List[Tensor]
      TypePtr elem_type = DynamicType::get();
      if (type_hint && type_hint->kind() == TypeKind::ListType) {
        elem_type = type_hint->expect<ListType>()->getElementType();
      } else if (!values.empty()) {
        elem_type = values.at(0)->type();
      }
      // Compare types structurally, not by TypePtr identity. Composite types
      // such as List[int] may be distinct objects that describe the same type,
      // so a pointer comparison wrongly rejected e.g. [[1], [2]] or a
      // List[List[int]] annotation.
      for (Value* v : values) {
        if (!v->type()->isSubtypeOf(elem_type)) {
          throw ErrorReport(tree)
              << "Lists must contain only a single type, expected: "
              << *elem_type << " but found " << *v->type() << " instead";
        }
      }
      return graph->insertNode(graph->createList(elem_type, values))->output();
    }
    case TK_TUPLE_LITERAL: {
      auto ll = TupleLiteral(tree);
      auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
      return graph->insertNode(graph->createTuple(values))->output();
    }
    default:
      throw ErrorReport(tree) << "NYI: " << tree;
  }
}
// Materializes a numeric literal as a graph constant. Integral and
// floating-point literals are routed through separate tables
// (integral_constants / fp_constants). Presumably these cache already
// materialized literals; see materializeConstant.
Value* emitConst(const Const& c) {
  if (!c.isFloatingPoint()) {
    return materializeConstant(c.asIntegral(), *graph, c.range(), integral_constants);
  }
  return materializeConstant(c.asFloatingPoint(), *graph, c.range(), fp_constants);
}
// Inserts the literal's text into the graph as a string constant.
Value* emitStringLiteral(const StringLiteral& c) {
  const auto& text = c.text();
  return insertConstant(*graph, text, c.range());
}
// Desugars select indexing: tensor[i] -> tensor.select(dim, i).
// `dim` is a compile-time dimension and is inserted as a graph constant;
// `index` is an already-emitted value.
Value* emitSelect(
    const SourceRange& loc,
    Value* input,
    int64_t dim,
    Value* index) {
  Value* dim_const = graph->insertConstant(dim, loc);
  return emitBuiltinCall(
      loc, *graph, aten::select, c10::nullopt,
      {input, dim_const, index}, {}, /*required=*/true);
}
// Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end, 1).
// `dim` must be set exactly when `input` is a tensor; lists and tuples are
// sliced without a dimension argument. A missing start defaults to 0. Tuples
// are handed to emitTupleSlice; everything else lowers to aten::slice with
// step 1.
Value* emitSlice(
    const SourceRange& loc,
    Value* input,
    c10::optional<int64_t> dim, // Only used for tensor slicing
    const SliceExpr& slice) {
  std::vector<NamedValue> args;
  args.reserve(4);
  args.emplace_back(loc, "self", input);
  // XXX: If list slicing becomes more complicated or stops using
  // aten::slice, we should separate it from this function.
  if (!dim) {
    JIT_ASSERT(!input->type()->isSubtypeOf(DynamicType::get()));
  } else {
    JIT_ASSERT(input->type()->isSubtypeOf(DynamicType::get()));
    args.emplace_back(loc, "dim", graph->insertConstant(*dim, loc));
  }

  Value* begin = emitExpr(Expr(slice.startOr(0)));
  args.emplace_back(loc, "begin", begin);

  const auto end_expr = slice.end();
  const bool has_end = end_expr.present();
  if (has_end) {
    args.emplace_back(loc, "end", emitExpr(Expr(end_expr.get())));
  }

  // Tuples never get a dim argument, so args is [self, begin(, end)].
  if (input->type()->cast<TupleType>()) {
    if (!has_end) {
      return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
    }
    return emitTupleSlice(loc, args[0], args[1], /*end*/args[2]);
  }

  NamedValue step(loc, "step", graph->insertConstant(1, loc));
  return emitBuiltinCall(loc, *graph, aten::slice, c10::nullopt, args, {step}, /*required=*/true);
}
// Lowers advanced (tensor) indexing to aten::index. The index tensors are
// first packed into a single Tensor[] list value.
Value* emitIndex(
    const SourceRange& loc,
    Value* input,
    at::ArrayRef<Value*> indices) {
  Node* index_list = graph->createList(DynamicType::get(), indices);
  Value* packed = graph->insertNode(index_list)->output();
  return emitBuiltinCall(
      loc, *graph, aten::index, c10::nullopt, {input, packed}, {}, /*required=*/true);
}
// Emits multidimensional slicing with int and slice indices.
// Subscripts are processed left to right while tracking `dim`, the current
// dimension of the (progressively sliced) value:
//  - a slice subscript emits aten::slice at `dim` and advances `dim`;
//  - an int subscript emits aten::select at `dim` and does NOT advance `dim`
//    (select drops the indexed dimension);
//  - a tensor subscript is not applied here; it is recorded at position `dim`
//    in the returned index list and `dim` is advanced.
// Any other index type raises an ErrorReport.
// Returns:
// - Value*: the input after it has been indexed by int and slice indices.
// - vector<Value*>: A list of tensor Value* indices that have not been applied yet.
//   Its length is one past the last tensor-indexed dimension (empty if there
//   were no tensor subscripts). Positions not indexed by a tensor are filled
//   with undefined-tensor values, so the list can be passed to at::index.
std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
const SourceRange& loc,
Value* sliceable,
const List<Expr>& subscript_exprs) {
std::vector<Value*> tensor_indices;
size_t dim = 0;
auto handle_tensor = [&](Value* tensor) {
// NB: tensor_indices can have NULL holes because of how at::index works.
tensor_indices.resize(dim + 1);
tensor_indices[dim] = tensor;
dim++;
};
for (const auto & subscript_expr : subscript_exprs) {
if (subscript_expr.kind() == TK_SLICE_EXPR) {
sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr));
++dim;
continue;
}
auto index = emitExpr(subscript_expr);
if (index->type() == IntType::get()) {
// dim intentionally not advanced: the selected dimension is removed.
sliceable = emitSelect(loc, sliceable, dim, index);
continue;
} else if (index->type()->isSubtypeOf(DynamicType::get())) {
handle_tensor(index);
continue;
}
throw ErrorReport(loc)
<< "Unsupported operation: indexing tensor with unsupported index type "
<< index->type()->str() << ". Only ints, slices, and tensors are supported.";
}
// at::index takes in a TensorList where some tensors can be undefined.
// Convert NULL tensorIndices to undefined tensors to pass to at::index.
for (auto& index : tensor_indices) {
if (index == nullptr) {
index = graph->insertNode(graph->createUndefined())->output();
}
}
return std::make_pair(sliceable, tensor_indices);
}
// Desugars multidim slicing into slice/select/index calls. Only tensors
// support this; any other type raises an ErrorReport.
//
// XXX: Errors in user code are not elegantly reported. For example,
//   @torch.jit.script
//   def fn(x):
//       return x[0, 1]
//   fn(torch.randn(5))
// is desugared into two aten::select ops, so the error message complains
// about aten::select failing rather than about there "not being enough
// dimensions to index".
//
// Strategy: first slice and select the tensor for every int and slice
// subscript in a single pass; call the result `sliced`. Then apply at::index
// to `sliced` with the tensor subscripts that remain. In tensor_indices:
//  - tensor_indices[i] = t if `sliced` should be indexed at dim i with tensor t;
//  - otherwise tensor_indices[i] is an undefined tensor (no indexing at dim i).
// tensor_indices only extends up to the last tensor-indexed dimension. If it
// is empty, no aten::index is emitted.
Value* emitMultidimSlicing(
    const SourceRange& loc,
    Value* sliceable,
    const List<Expr>& subscript_exprs) {
  if (!sliceable->type()->isSubtypeOf(DynamicType::get())) {
    throw ErrorReport(loc)
        << "Unsupported operation: attempted to use multidimensional "
        << "indexing on a non-tensor type.";
  }
  auto sliced_and_indices = emitIntAndSliceIndexing(loc, sliceable, subscript_exprs);
  Value* sliced = sliced_and_indices.first;
  const std::vector<Value*>& tensor_indices = sliced_and_indices.second;
  if (!tensor_indices.empty()) {
    return emitIndex(loc, sliced, tensor_indices);
  }
  // XXX: Might need to at::alias this when we support mutability
  return sliced;
}
// Desugars a single slice subscript. Tensors slice along dim 0,
// tensor[begin:end] -> tensor.slice(0, begin, end); lists and tuples are
// sliced without a dimension argument.
Value* emitBasicSlice(
    const SourceRange& loc,
    Value* sliceable,
    const List<Expr>& subscript_exprs) {
  JIT_ASSERT(subscript_exprs.size() == 1);
  JIT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
  // If the sliceable object is a tensor, specify a default dimension.
  const bool is_tensor = sliceable->type()->isSubtypeOf(DynamicType::get());
  c10::optional<int64_t> maybe_dim =
      is_tensor ? c10::optional<int64_t>(0) : c10::nullopt;
  return emitSlice(loc, sliceable, maybe_dim, SliceExpr(subscript_exprs[0]));
}
// Resolves a tuple index, which must be an integer constant, to a position.
// Negative indices count from the end, Python-style. When allow_out_of_bounds
// is false, a position outside [0, len) raises an ErrorReport. When it is
// true, the returned value may be negative or >= len and the caller must clamp.
int64_t getTupleIndexVal(const SourceRange& loc,
    const TupleTypePtr& tuple_type,
    Value * idx_val,
    bool allow_out_of_bounds) {
  auto constant = toIValue(idx_val);
  if (!constant || !constant->isInt()) {
    throw ErrorReport(loc)
        << "tuple indices must be integer constants";
  }
  const int64_t index = constant->to<int64_t>();
  const int64_t tuple_len = tuple_type->elements().size();
  // Normalize negative indices so the runtime only sees offsets from the start.
  const int64_t adj_index = index < 0 ? tuple_len + index : index;
  const bool in_bounds = adj_index >= 0 && adj_index < tuple_len;
  if (!allow_out_of_bounds && !in_bounds) {
    throw ErrorReport(loc)
        << "Tuple index out of range. Tuple is length " << tuple_len
        << " and index is " << index;
  }
  return adj_index;
}
// Emits tup[i] for a constant i as a TupleIndex node. An out-of-range index
// is a compile-time error (raised by getTupleIndexVal).
Value* emitTupleIndex(const SourceRange& loc,
    Value * tuple_val,
    Value * idx_val) {
  const int64_t position = getTupleIndexVal(
      loc, tuple_val->type()->cast<TupleType>(), idx_val, /*allow_out_of_bounds*/false);
  Node* tuple_index = graph->createTupleIndex(tuple_val, position);
  return graph->insertNode(tuple_index)->output();
}
// Emits tup[beg:end] with constant bounds as a TupleSlice node. A missing end
// means "through the last element". As in Python, slicing never raises for
// out-of-range bounds: both bounds are clamped into [0, len].
Value* emitTupleSlice(const SourceRange& loc,
    const NamedValue& tuple_val,
    const NamedValue& beg_val,
    const at::optional<NamedValue>& end_val) {
  Value* tuple = tuple_val.value(*graph);
  auto tuple_type = tuple->type()->expect<TupleType>();
  const int64_t tuple_len = tuple_type->elements().size();
  const auto clamp_to_tuple = [tuple_len](int64_t i) {
    return std::min(std::max((int64_t)0, i), tuple_len);
  };
  const int64_t beg = getTupleIndexVal(
      loc, tuple_type, beg_val.value(*graph), /*allow_out_of_bounds*/true);
  const int64_t end = end_val
      ? getTupleIndexVal(loc, tuple_type, end_val->value(*graph), /*allow_out_of_bounds*/true)
      : tuple_len;
  Node* slice_node = graph->createTupleSlice(tuple, clamp_to_tuple(beg), clamp_to_tuple(end));
  return graph->insertNode(slice_node)->output();
}
// Entry point for `value[...]`: emits the subscripted value, then dispatches
// on the subscript list.
Value* emitSubscript(const Subscript& subscript) {
  Value* sliceable = emitExpr(subscript.value());
  return emitSubscript(subscript.range(), sliceable, subscript.subscript_exprs());
}
// Dispatches a subscript list:
//  - exactly one slice -> emitBasicSlice
//  - exactly one non-slice subscript -> emitBasicGather
//  - any other count -> emitMultidimSlicing
Value* emitSubscript(
    const SourceRange& loc,
    Value* sliceable,
    const List<Expr>& subscript_exprs) {
  const bool single = subscript_exprs.size() == 1;
  if (single && subscript_exprs[0].kind() == TK_SLICE_EXPR) {
    return emitBasicSlice(loc, sliceable, subscript_exprs);
  }
  if (single) {
    return emitBasicGather(loc, sliceable, subscript_exprs);
  }
  return emitMultidimSlicing(loc, sliceable, subscript_exprs);
}
// Desugars gather syntactic sugar foo[i] for the three indexable kinds:
//  - lists:   aten::select(list, i)
//  - tensors: routed through emitMultidimSlicing
//  - tuples:  a constant-index TupleIndex node
// Any other type raises an ErrorReport.
Value* emitBasicGather(
    const SourceRange& loc,
    Value* gatherable,
    const List<Expr>& subscript_exprs) {
  JIT_ASSERT(subscript_exprs.size() == 1);
  const TypePtr gather_type = gatherable->type();
  if (gather_type->kind() == TypeKind::ListType) {
    // if it's a list, emit a regular index selection op
    Value* idx = emitExpr(subscript_exprs[0]);
    return emitBuiltinCall(
        loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, /*required=*/true);
  }
  if (gather_type->isSubtypeOf(DynamicType::get())) {
    return emitMultidimSlicing(loc, gatherable, subscript_exprs);
  }
  if (gather_type->cast<TupleType>()) {
    Value* idx = emitExpr(subscript_exprs[0]);
    return emitTupleIndex(loc, gatherable, idx);
  }
  throw ErrorReport(loc)
      << "Indexing only supported on lists, tensors, and tuples.";
}
};
// Tensor method names that are sugar for dtype casts, mapped to the aten
// operator each one lowers to (e.g. x.int() -> aten::_cast_Int(x)).
static const std::unordered_map<std::string, std::string> &builtin_cast_methods() {
  static const std::unordered_map<std::string, std::string> cast_table{
      {"byte", "_cast_Byte"},
      {"char", "_cast_Char"},
      {"short", "_cast_Short"},
      {"int", "_cast_Int"},
      {"long", "_cast_Long"},
      {"half", "_cast_Half"},
      {"float", "_cast_Float"},
      {"double", "_cast_Double"},
  };
  return cast_table;
}
// Resolves `value.field`. This provides method-call sugar: x.foo(y, z) works
// by making x.foo a callable that resolves to foo(x, y, z) when called.
// For tensors:
//   - cast methods (x.int(), x.float(), ...) become aten::_cast_* builtins;
//   - the properties dtype, device, shape, is_cuda and requires_grad are read
//     immediately through prim::<name>(Tensor t) -> <return_type> ops.
// Calling a method on a number is an error. Any other field becomes
// aten::<field>, with this value bound as `self`.
std::shared_ptr<SugaredValue> SimpleValue::attr(SourceRange loc, Method & m, const std::string& field) {
  const bool is_tensor = value->type()->isSubtypeOf(DynamicType::get());
  if (is_tensor) {
    // Allow method-style casts on Tensor types. e.g. x.int()
    const auto& casts = builtin_cast_methods();
    auto cast_it = casts.find(field);
    if (cast_it != casts.end()) {
      return std::make_shared<BuiltinFunction>(
          Symbol::aten(cast_it->second), NamedValue(loc, "self", value));
    }
    // Functions that are just direct property lookups on a tensor
    // must be registered as prim::<name>(Tensor t) -> <return_type>.
    static const std::unordered_set<std::string> tensor_properties = {
        "dtype", "device", "shape", "is_cuda", "requires_grad",
    };
    if (tensor_properties.count(field) != 0) {
      auto property = m.graph()->insert(Symbol::fromQualString("prim::" + field), {value});
      return std::make_shared<SimpleValue>(property);
    }
  }
  if (getValue()->type()->isSubtypeOf(NumberType::get())) {
    throw ErrorReport(loc) << "Cannot call methods on numbers";
  }
  return std::make_shared<BuiltinFunction>(
      Symbol::aten(field), NamedValue(loc, "self", value));
}
// Inlines `callee` into `g`. Every node of callee is cloned with
// g.insertNode, with callee's graph inputs wired to `inputs`. Returns the
// values in `g` that correspond to callee's outputs, in order.
std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
  JIT_ASSERT(callee.inputs().size() == inputs.size());
  // Maps each callee value to its counterpart in g. at() throws on a missing
  // entry instead of silently inserting one.
  std::unordered_map<Value*, Value*> remap;
  const auto lookup = [&](Value* v) { return remap.at(v); };

  auto callee_inputs = callee.inputs();
  for (size_t i = 0; i != inputs.size(); ++i) {
    remap[callee_inputs[i]] = inputs[i];
  }

  for (Node* original : callee.nodes()) {
    Node* clone = g.insertNode(g.createClone(original, lookup));
    auto old_outs = original->outputs();
    auto new_outs = clone->outputs();
    for (size_t i = 0; i != old_outs.size(); ++i) {
      remap[old_outs[i]] = new_outs[i];
    }
  }

  std::vector<Value*> results;
  for (Value* out : callee.outputs()) {
    results.push_back(lookup(out));
  }
  return results;
}
// Defines each of `definitions` as a method of module `m`, compiling the
// i-th definition with the i-th resolver (both vectors must be the same
// length). All methods are registered first (create_method) and only then
// compiled (ensure_defined). When `self` is null, the definitions are free
// functions: each resolver is wrapped so that names of functions in this
// same batch resolve to those methods before falling back to the original
// resolver. Finally, didFinishEmitModule(m) is called.
void defineMethodsInModule(std::shared_ptr<Module> m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, SugaredValuePtr self) {
JIT_ASSERT(definitions.size() == resolvers.size());
auto resolver_it = resolvers.begin();
std::vector<Method*> methods;
// Name -> method for this batch; consulted by the wrapped resolvers below.
std::unordered_map<std::string, Method*> function_table;
for(Def def : definitions) {
const std::string& name = def.name().name();
auto resolver = *resolver_it++;
JIT_ASSERT(resolver);
if(!self) {
// if self is defined, then these are methods and do not go into the global namespace
// otherwise, they get defined together so we add them to the function table
// so the methods can see each other
// NOTE(review): function_table is captured by reference. This is only safe
// if the resolver is never invoked after this function returns; presumably
// the ensure_defined() loop below compiles everything eagerly. Confirm that
// Method does not keep the creator/resolver around afterwards.
resolver = [resolver, &function_table](
const std::string& name,
Method& m,
const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
auto it = function_table.find(name);
if (it != function_table.end()) {
return std::make_shared<MethodValue>(nullptr, *it->second);
}
return resolver(name, m, loc);
};
}
// Compilation is deferred: the creator lowers the Def to IR when the
// method is first defined.
auto creator = [def, resolver, self](Method& method) {
JIT_ASSERT(resolver);
to_ir(def, resolver, self, method);
};
Method& method = m->create_method(name, creator);
function_table[name] = &method;
methods.push_back(&method);
}
// Compile only after every method in the batch has been registered, so a
// definition can refer to one that appears later in `definitions`.
for(Method* method : methods) {
method->ensure_defined();
}
didFinishEmitModule(m);
}
const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
static std::unordered_map<std::string, TypePtr> map = {
{"Tensor", DynamicType::get()},
{"int", IntType::get()},
{"float", FloatType::get()},
{"bool", BoolType::get()},
{"str", StringType::get()},
// technically this is not a python type but we need it when
// parsing serialized methods that use implicit converions to Scalar
{"number", NumberType::get()},
};
return map;
}
const std::unordered_map<std::string, std::function<TypePtr(Subscript)>> &subscript_to_type_fns() {
static std::unordered_map<std::string, std::function<TypePtr(Subscript)>> map = {
{"Tuple", [](Subscript subscript) -> TypePtr {
std::vector<TypePtr> subscript_expr_types;
for (auto expr : subscript.subscript_exprs()) {
subscript_expr_types.push_back(parseTypeFromExpr(expr));
}
return TupleType::create(subscript_expr_types);
}},
{"List", [](Subscript subscript) -> TypePtr {
if (subscript.subscript_exprs().size() != 1) {
throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
}
auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
return ListType::create(elem_type);
}},
{"Optional", [](Subscript subscript) -> TypePtr {
if (subscript.subscript_exprs().size() != 1) {
throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
}
auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
return OptionalType::create(elem_type);
}},
};
return map;
}
// True iff `expr` is the bare identifier `torch`.
bool isTorch(Expr expr) {
  if (expr.kind() != TK_VAR) {
    return false;
  }
  return Var(expr).name().name() == "torch";
}
// Gets the base type name, stripping the namespaces the types live in:
// `torch.Tensor` -> "Tensor", `X` -> "X". Returns nullopt for any other
// expression (including other attribute selects).
c10::optional<std::string> parseBaseTypeName(Expr expr) {
  const auto kind = expr.kind();
  if (kind == TK_VAR) {
    return Var(expr).name().name();
  }
  if (kind == '.') {
    auto select = Select(expr);
    if (isTorch(select.value()) && select.selector().name() == "Tensor") {
      return std::string("Tensor");
    }
  }
  return c10::nullopt;
}
// Parses a type annotation expression into a TypePtr. Accepts subscripted
// constructors (see subscript_to_type_fns) and plain type names (see
// ident_to_type_lut), optionally namespaced as torch.Tensor. Anything else
// raises an ErrorReport.
TypePtr parseTypeFromExpr(Expr expr) {
  if (expr.kind() != TK_SUBSCRIPT) {
    if (auto name = parseBaseTypeName(expr)) {
      const auto& lut = ident_to_type_lut();
      auto found = lut.find(*name);
      if (found == lut.end()) {
        throw ErrorReport(expr) << "Unknown type name " << *name;
      }
      return found->second;
    }
    throw ErrorReport(expr.range()) << "Expression of type " << kindToString(expr.kind())
                                    << " cannot be used in a type expression";
  }
  auto subscript = Subscript(expr);
  auto ctor_name = parseBaseTypeName(subscript.value());
  if (!ctor_name) {
    throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
  }
  const auto& ctors = subscript_to_type_fns();
  auto ctor = ctors.find(*ctor_name);
  if (ctor == ctors.end()) {
    throw ErrorReport(subscript.range()) << "Unknown type constructor " << *ctor_name;
  }
  return ctor->second(subscript);
}
// Recognizes the annotations BroadcastingList<N>[T] and
// Optional[BroadcastingList<N>[T]], where T is `int` or `float` and <N> is a
// positive integer suffix on the identifier.
// Returns (List[T], N), or (Optional[List[T]], N) for the Optional form.
// Returns nullopt if `expr` is not one of these annotations, so the caller
// can presumably fall back to ordinary type parsing.
// Throws ErrorReport for a malformed BroadcastingList: wrong subscript
// count, a non-identifier or non-int/float element type, or a suffix that is
// not a positive integer.
c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(Expr expr) {
if (expr.kind() != TK_SUBSCRIPT)
return c10::nullopt;
auto subscript = Subscript(expr);
if (subscript.value().kind() != TK_VAR)
return c10::nullopt;
auto var = Var(subscript.value());
auto subscript_exprs = subscript.subscript_exprs();
// handle the case where the BroadcastingList is wrapped in a Optional type
// NOTE(review): subscript_exprs[0] is read here before the size() == 1 check
// below; this relies on the parser never producing an empty subscript list.
if(var.name().name() == "Optional") {
auto broadcast_list = handleBroadcastList(subscript_exprs[0]);
if (broadcast_list) {
TypePtr opt_type = OptionalType::create(broadcast_list->first);
return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
} else {
return c10::nullopt;
}
} else if (var.name().name().find("BroadcastingList") != 0) {
// Only identifiers that start with "BroadcastingList" are handled here.
return c10::nullopt;
}
if (subscript_exprs.size() != 1)
throw ErrorReport(subscript.subscript_exprs().range())
<< "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
auto typ = subscript_exprs[0];
// The list length is encoded as the identifier suffix, e.g. "2" in BroadcastingList2.
auto len = var.name().name().substr(strlen("BroadcastingList"));
if (typ.kind() != TK_VAR)
throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
auto value_name = Var(typ).name().name();
if (value_name != "float" && value_name != "int")
throw ErrorReport(subscript.value().range()) << "Broadcastable lists only supported for int or float";
auto elem_ptr = ident_to_type_lut().find(value_name);
JIT_ASSERT(elem_ptr != ident_to_type_lut().end());
TypePtr list_ptr = ListType::create(elem_ptr->second);
// Parse the suffix with the script parser so it must be an integer literal.
Parser const_parser(len);
auto constant = const_parser.parseConst();
if (!constant.isIntegral() || constant.asIntegral() <= 0) {
throw ErrorReport(subscript.subscript_exprs().range())
<< "subscript of Broadcastable list must be positive integer";
}
// NOTE(review): asIntegral() is narrowed to int32_t in the returned pair.
auto len_v = constant.asIntegral();
return std::pair<TypePtr, int32_t>(list_ptr, len_v);
}
void defineMethodsInModule(std::shared_ptr<Module> m, const std::string& source, Resolver resolver, SugaredValuePtr self) {
Parser p(source);
std::vector<Def> definitions;
std::vector<Resolver> resolvers;
while (p.lexer().cur().kind != TK_EOF) {
auto def = Def(p.parseFunction(/*is_method=*/bool(self)));
definitions.push_back(def);
resolvers.push_back(resolver);
}
defineMethodsInModule(m, definitions, resolvers, self);
}
// Unpacks this value for use where a tuple is expected. Tuples are unpacked
// element-wise. Lists are unpacked into exactly *size_hint elements; the hint
// is required because a list's length is not known statically. Any other
// type raises an ErrorReport.
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
    SourceRange loc,
    Method& m,
    c10::optional<size_t> size_hint) {
  auto to_sugared = [](Value* v) -> std::shared_ptr<SugaredValue> {
    return std::make_shared<SimpleValue>(v);
  };
  const auto kind = value->type()->kind();
  if (kind == TypeKind::TupleType) {
    return fmap(createTupleUnpack(value), to_sugared);
  }
  if (kind != TypeKind::ListType) {
    throw ErrorReport(loc) << value->type()->str() << " cannot be used as a tuple";
  }
  if (!size_hint) {
    throw ErrorReport(loc) << "cannot statically infer the expected size of a list in this context";
  }
  Graph* owner = value->owningGraph();
  Node* unpack = owner->insertNode(owner->createListUnpack(value, *size_hint));
  return fmap(unpack->outputs(), to_sugared);
}
} // namespace script
} // namespace jit
} // namespace torch