// blob: 4143d0d77548caeab1a9f1e5f1978339715a66dc [file] [log] [blame]
#include <ATen/ATen.h>
#include <torch/csrc/jit/alias_info.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/edit_distance.h>
#include <queue>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace {
// Maps the Symbol built from an operator's qualified schema name to every
// registered Operator sharing that name (the registry below appends one
// entry per registered operator).
using OperatorMap =
std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
struct OperatorRegistry {
private:
std::mutex lock;
OperatorMap operators;
// list of operators whose schema have not yet been parsed, and must
// be registered before any call to lookup an operator
std::vector<std::shared_ptr<Operator>> to_register;
// Those two maps are used to implement lookupByLiteral, which is needed for
// the n->match(...) calls. Basically, every function schema is assigned a
// unique string you can use to match it. However, parsing those strings or
// comparing and hashing them character by character would be very slow, so we
// use a trick here! Every string literal in your program is guaranteed to
// have static storage duration and so its address won't change at runtime.
// This allows us to memoize answers for every pointer, which is done by the
// operators_by_sig_literal map. Still, this map is initially empty, and so we
// still need to do the complete string matching at the first time, which is
// implemented by performing a lookup in the operators_by_sig map.
std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
std::unordered_map<const char*, std::shared_ptr<Operator>>
operators_by_sig_literal;
// XXX - caller must be holding lock
void registerPendingOperators() {
for (const auto& op : to_register) {
Symbol sym = Symbol::fromQualString(op->schema().name());
operators[sym].push_back(op);
operators_by_sig[canonicalSchemaString(op->schema())] = op;
}
to_register.clear();
}
public:
void registerOperator(Operator&& op) {
std::lock_guard<std::mutex> guard(lock);
to_register.push_back(std::make_shared<Operator>(std::move(op)));
}
const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
std::lock_guard<std::mutex> guard(lock);
registerPendingOperators();
auto it = operators_by_sig_literal.find(name);
if (it == operators_by_sig_literal.end()) {
auto op_ptr_it =
operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
// Handy debugging code that dumps all operators we know about on mismatch
#if 0
if (op_ptr_it == operators_by_sig.end()) {
for (auto & entry : operators_by_sig) {
std::cout << entry.first << std::endl;
}
}
#endif
TORCH_CHECK(
op_ptr_it != operators_by_sig.end(),
"Couldn't find an operator for ",
name,
". Do you have to update a set of hardcoded JIT ops?");
it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
}
return it->second;
}
const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
std::lock_guard<std::mutex> guard(lock);
registerPendingOperators();
static std::vector<std::shared_ptr<Operator>> empty;
auto it = operators.find(name);
if (it != operators.end())
return it->second;
return empty;
}
std::vector<Symbol> findSimilarOperators(Symbol input_op) {
std::lock_guard<std::mutex> guard(lock);
registerPendingOperators();
using EntryPair = std::pair<int64_t, Symbol>;
auto cmp = [](const EntryPair& lhs, const EntryPair& rhs) {
return lhs.first > rhs.first;
};
std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
rankings(cmp);
static constexpr size_t MAX_EDIT_DIST = 2u;
for (const auto& op : operators) {
auto edit_dist = script::ComputeEditDistance(
input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST);
if (edit_dist <= MAX_EDIT_DIST) {
rankings.emplace(edit_dist, op.first);
}
}
std::vector<Symbol> ret;
while (!rankings.empty()) {
ret.push_back(rankings.top().second);
rankings.pop();
}
return ret;
}
const std::vector<std::shared_ptr<Operator>> getAllOperators() {
std::lock_guard<std::mutex> guard(lock);
registerPendingOperators();
std::vector<std::shared_ptr<Operator>> values;
values.clear();
for (auto & kv : operators) {
values.insert(values.end(), kv.second.begin(), kv.second.end());
}
return values;
}
};
// Accessor for the single OperatorRegistry used by the functions in this
// file. The instance is a function-local static, so it is constructed the
// first time this function is called.
OperatorRegistry& getRegistry() {
  static OperatorRegistry registry;
  return registry;
}
// Returns true when `sym` is acceptable to the Python printer without a
// schema: it is listed in `handled` or `unneeded` below, or its namespace is
// not one of `required_namespaces`.
bool printerHasSpecialCaseFor(Symbol sym) {
  using namespace at;
  // WARNING: by adding a value to this set, you are asserting
  // that you have also added special handling of this symbol to
  // the python_print.cpp. Not adding handling will cause import and export
  // of modules with this new operator to fail. This is only required
  // for operators without schema. Prefer registering your operator with
  // schema to editing this list here. These cases should only be things
  // that require special handling because they do not fit normal schema
  const static std::unordered_set<Symbol> handled = {
      prim::Constant,
      prim::Uninitialized,
      prim::fork,
      prim::ListConstruct,
      prim::DictConstruct,
      prim::ListUnpack,
      prim::Print,
      prim::PythonOp,
      prim::TupleConstruct,
      prim::TupleIndex,
      prim::TupleSlice,
      prim::TupleUnpack,
      prim::CreateObject,
      prim::GetAttr,
      prim::SetAttr,
      prim::CallFunction,
      prim::isinstance,
      prim::unchecked_cast,
  };

  // WARNING: by adding a value to this set, you are asserting that your
  // primitive is only ever added during optimization and does not need
  // to be correctly printed for export (a process that happens before
  // optimization passes run)
  const static std::unordered_set<Symbol> unneeded = {
      c10::onnx::Reshape, // only used in onnx
      c10::onnx::Shape, // only used in onnx
      prim::AutogradZero, // temporarily inserted by autograd
      prim::AutogradAnyNonZero, // temporarily inserted by autograd
      prim::AutogradAdd, // temporarily inserted by autograd
      prim::ConstantChunk, // optimization pass adds it
      prim::DifferentiableGraph, // optimization pass adds it
      prim::BroadcastSizes, // optimization pass (fuser) adds it
      prim::ChunkSizes, // optimization pass (fuser) adds it
      prim::Drop, // used in interpreter only
      prim::FusedConcat, // optimization pass adds it
      prim::FusionGroup, // optimization pass adds it
      prim::Load, // used in interpreter only
      prim::MMTreeReduce, // used as an optimization
      prim::MMBatchSide, // used as an optimization
      prim::Store, // used in interpreter only
      prim::profile, // used in interpreter only
  };

  // These namespaces are required to have Python printers unless
  // otherwise noted in unneeded.
  const static std::unordered_set<Symbol> required_namespaces = {
      c10::namespaces::prim,
      c10::namespaces::aten,
      c10::namespaces::onnx,
  };

  // Explicitly listed symbols pass regardless of namespace.
  if (handled.count(sym) || unneeded.count(sym)) {
    return true;
  }
  // Otherwise only symbols outside the required namespaces pass.
  return !required_namespaces.count(sym.ns());
}
} // anonymous namespace
// Returns true when `symbol` appears in either table below: `handled`
// (symbols asserted to have a dedicated case in alias analysis) or
// `purposefully_not_handled` (symbols alias analysis is not meant to see).
bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
  using namespace at;
  // WARNING: by adding a case to this list, you are asserting that you have
  // added a case for the unschematized node in AliasDb::analyze
  const static std::unordered_set<Symbol> handled = {
      prim::If,
      prim::Loop,
      prim::FusionGroup,
      prim::DifferentiableGraph,
      prim::Constant,
      prim::Uninitialized,
      prim::DictConstruct,
      prim::ListConstruct,
      prim::TupleConstruct,
      prim::AutogradZero,
      prim::FusedConcat,
      prim::GradOf,
      prim::MMTreeReduce,
      prim::MMBatchSide,
      prim::BroadcastSizes,
      prim::ChunkSizes,
      prim::Function,
      prim::TupleUnpack,
      prim::TupleIndex,
      prim::TupleSlice,
      prim::ListUnpack,
      prim::PythonOp,
      prim::ConstantChunk,
      prim::BroadcastingChunk,
      prim::fork,
      prim::CreateObject,
      prim::AutogradAdd,
      prim::GetAttr,
      prim::SetAttr,
      prim::profile,
      prim::Print,
      prim::CallFunction,
      prim::CallMethod,
      aten::wait,
      prim::isinstance,
      prim::unchecked_cast,
  };

  // Operators that should not be used by alias analysis
  // NOTE(review): prim::AutogradAdd is listed both here and in `handled`
  // above. The result of this function is the same either way, but one of
  // the two listings is presumably unintended -- confirm which.
  const static std::unordered_set<Symbol> purposefully_not_handled = {
      prim::Load,
      prim::Store,
      prim::Drop,
      at::onnx::Reshape,
      at::onnx::Shape,
      prim::AutogradAdd,
  };

  if (handled.count(symbol)) {
    return true;
  }
  return purposefully_not_handled.count(symbol);
}
// Validates `op` and hands it to the registry.
//
// Operators whose schema reports is_varret() are treated as non-schematized
// and must satisfy three conditions, each enforced with TORCH_CHECK:
//   1. the Python printer has a special case for the operator's symbol;
//   2. if alias analysis has no special case for it, the operator must not
//      use AliasAnalysisKind::CONSERVATIVE;
//   3. if alias analysis does have a special case for it, the operator must
//      not use AliasAnalysisKind::FROM_SCHEMA.
// All other operators are registered without these checks.
void registerOperator(Operator&& op) {
  if (op.schema().is_varret()) {
    const Symbol s = Symbol::fromQualString(op.schema().name());
    TORCH_CHECK(
        printerHasSpecialCaseFor(s),
        "Missing special case in python printer for non-schematized"
        " operator ",
        op.schema().name(),
        ". File a bug to add a case for this operator.\n");

    // The two alias-analysis checks are mutually exclusive, so the table
    // lookup is done once and the matching check is selected from it.
    const bool alias_special_cased = aliasAnalysisHasSpecialCaseFor(s);
    if (alias_special_cased) {
      TORCH_CHECK(
          op.aliasAnalysisKind() != AliasAnalysisKind::FROM_SCHEMA,
          "The operator ",
          op.schema().name(),
          " is special cased and cannot use explicit alias analysis.");
    } else {
      TORCH_CHECK(
          op.aliasAnalysisKind() != AliasAnalysisKind::CONSERVATIVE,
          "Missing special case in alias analysis for non-schematized"
          " operator ",
          op.schema().name(),
          ". File a bug to add a case for this operator.\n");
    }
  }
  getRegistry().registerOperator(std::move(op));
}
const std::vector<std::shared_ptr<Operator>> getAllOperators() {
return getRegistry().getAllOperators();
}
// Returns a reference to the registry's list of operators registered under
// `name`, as produced by the registry's getOperators().
const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
  OperatorRegistry& registry = getRegistry();
  return registry.getOperators(name);
}
// Looks up the operator registered under `full_name.name` whose schema has
// overload name `full_name.overload_name`. Returns the first match, or
// nullptr when no registered operator matches.
std::shared_ptr<Operator> findOperatorFor(const c10::OperatorName& full_name) {
  const Symbol base_name = Symbol::fromQualString(full_name.name);
  const auto& candidates = getRegistry().getOperators(base_name);
  for (const auto& candidate : candidates) {
    if (candidate->schema().overload_name() == full_name.overload_name) {
      return candidate;
    }
  }
  return nullptr;
}
// Forwards to the registry's findSimilarOperators() and returns its result:
// the names of registered operators it considers close to `input_op`.
std::vector<Symbol> findSimilarOperators(Symbol input_op) {
  OperatorRegistry& registry = getRegistry();
  return registry.findSimilarOperators(input_op);
}
std::shared_ptr<Operator> getOperatorForLiteral(const char* signature) {
return getRegistry().lookupByLiteral(signature);
}
// Renders `schema` as a single string of the form
//   name(type arg, ..., *, type kwarg, ...) -> ret
// where:
//   - "*, " is emitted once, immediately before the first keyword-only
//     argument;
//   - a single return is printed as its bare type;
//   - several returns are printed as a parenthesized, comma-separated list;
//   - with no returns, nothing follows "-> ".
std::string canonicalSchemaString(const FunctionSchema& schema) {
  std::ostringstream out;
  out << schema.name() << "(";

  bool seen_kwarg_only = false;
  const char* arg_sep = "";
  for (const auto& arg : schema.arguments()) {
    out << arg_sep;
    arg_sep = ", ";
    if (!seen_kwarg_only && arg.kwarg_only()) {
      out << "*, ";
      seen_kwarg_only = true;
    }
    out << arg.type()->str() << " " << arg.name();
  }
  out << ") -> ";

  const auto& rets = schema.returns();
  if (rets.size() == 1) {
    out << rets.at(0).type()->str();
  } else if (rets.size() > 1) {
    out << "(";
    const char* ret_sep = "";
    for (const auto& ret : rets) {
      out << ret_sep << ret.type()->str();
      ret_sep = ", ";
    }
    out << ")";
  }
  return out.str();
}
} // namespace jit
} // namespace torch