blob: 90bafd7c4bb7ecaf459b096a3ced6fd3f675a280 [file] [log] [blame]
// In-memory description of all ATen ops, similar to the Caffe2 schema.
// Once C10 exists this can be removed or stubbed out, but we need
// it now to implement correct semantic checking for TorchScript.
#pragma once
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/script/function_schema_parser.h>
#include <torch/csrc/jit/operator_options.h>
#include <ATen/core/stack.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/interned_strings.h>
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
struct Node;
using ::c10::Symbol;
using ::c10::FunctionSchema;
using OperationCreator = Operation (*)(const Node*);
/*
* Note: JIT relies on Operator instances having static lifetime because,
* for example, it stores a non-owning FunctionSchema* pointer in the Node class,
* which points to the function schema stored in the Operator instance.
* Also, jit::Operator is meant to store more operator related information like
* symbolic derivatives, which also requires them to have static lifetime
* so that changes to symbolic derivatives are remembered.
*
* Now, currently, the c10 operator library doesn't store jit::Operator instances,
* but we use a listener pattern that notifies JIT about changes in the
* c10 operator library and then registers jit::Operator instances to the JIT
* operator registry, acting as wrappers to the c10 operators.
*
* However, that results in code duplication as JIT and c10 will likely get
* their own mechanisms for storing derivatives and other operator related
* information, and all of this would have to be wrapped from c10 into JIT.
*
* We should consider merging the JIT and c10 registries, moving jit::Operator
* to c10 and storing these jit::Operator instances in the c10 operator library
* instead, allowing us to have these mechanisms only implemented once.
* However, the current jit::Operator implementation has additional features
* like OperationCreator that aren't needed in c10 (they're only used for
* prim ops like If/Else or While which wouldn't be in the c10 operator library),
* and which depend on other JIT features which we don't want to move to c10
* (notably jit/ir.h). We might, however, be able to split jit::Operator into
* a c10::Operator with the core features and a jit::Operator that adds the
* JIT-only features like OperationCreator, and then use c10::Operator in the
* c10 operator library.
*/
struct TORCH_API Operator {
Operator(c10::OperatorHandle opHandle, Operation operation)
: schema_(std::make_shared<FunctionSchema>(opHandle.schema())),
op_(std::make_shared<Operation>(std::move(operation))),
c10Handle_(opHandle),
options_(c10Handle_->options()) {}
Operator(
FunctionSchema schema,
OperationCreator op_creator,
c10::OperatorOptions options = c10::OperatorOptions())
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
op_creator_(std::move(op_creator)),
options_(std::move(options)) {}
Operator(
const std::string& schema,
OperationCreator op_creator,
c10::OperatorOptions options = c10::OperatorOptions())
: schema_string_(schema),
op_creator_(std::move(op_creator)),
options_(std::move(options)) {}
// Helper constructor to register `op` to run
// run for _every_ IR Node where n.kind() == name, regardless of arguments.
// This is accomplished by marking the schema varargs and having no required
// arguments. This is used for things like prim::While or prim::If that can
// take a number of different valid input types and lengths.
Operator(
Symbol name,
OperationCreator op_creator,
c10::OperatorOptions options = c10::OperatorOptions())
: Operator(
varArgSchemaWithName(name),
std::move(op_creator),
std::move(options)) {}
Operator(
Symbol name,
Operation op,
c10::OperatorOptions options = c10::OperatorOptions())
: Operator(
varArgSchemaWithName(name),
std::move(op),
std::move(options)) {}
Operator(
FunctionSchema schema,
Operation op,
c10::OperatorOptions options = c10::OperatorOptions())
: schema_(std::make_shared<FunctionSchema>(std::move(schema))),
op_(std::make_shared<Operation>(std::move(op))),
options_(std::move(options)) {}
Operator(
const std::string& schema,
int(*op)(Stack&),
c10::OperatorOptions options = c10::OperatorOptions())
: schema_string_(schema),
op_(std::make_shared<Operation>(std::move(op))),
options_(std::move(options)) {}
bool matches(const Node* node) const;
Operation getOperation(const Node* node = nullptr) const {
if (op_) {
return *op_;
}
AT_ASSERT(node != nullptr);
return op_creator_(node);
}
const FunctionSchema& schema() const {
// we lazily parse schema initialized from strings so that
// we do less work during static operator registration
if (!schema_) {
schema_ =
std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
schema_string_ = c10::nullopt;
}
return *schema_;
}
bool isC10Op() const {
return c10Handle_.has_value();
}
c10::AliasAnalysisKind aliasAnalysisKind() const {
return options_.aliasAnalysis();
}
private:
static FunctionSchema varArgSchemaWithName(Symbol name) {
return FunctionSchema(
name,
"",
{},
{},
/*is_vararg*/ true,
/*is_varret*/ true);
}
mutable c10::optional<std::string> schema_string_;
// cannot use c10::optional because windows has issues that require an
// assignment operator to be generated cannot use std::unique_ptr because
// initializer lists of Operators end up copying the Operator
mutable std::shared_ptr<FunctionSchema> schema_;
// Essentially a variant<Operation, OperationCreator>.
// NB: std::function has a default state (where it == nullptr).
std::shared_ptr<Operation> op_;
OperationCreator op_creator_;
c10::optional<c10::OperatorHandle> c10Handle_;
c10::OperatorOptions options_;
};
// Returns a string rendering of `schema` (presumably a normalized/canonical
// form suitable for comparison -- confirm the exact format in the definition).
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
// Returns, by value, a vector of shared_ptrs to operators -- presumably every
// operator added via registerOperator(); confirm in the definition.
TORCH_API const std::vector<std::shared_ptr<Operator>> getAllOperators();
// Returns, by const reference, the operators associated with `name`.
// NOTE(review): the referenced vector is owned elsewhere; confirm how long it
// stays valid if more operators are registered.
TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
Symbol name);
// Looks up an Operator for `node`; presumably returns nullptr when no
// operator matches -- confirm in the definition.
std::shared_ptr<Operator> findOperatorFor(const Node* node);
// Returns the Operator for `node`. The returned Operator is guaranteed to
// satisfy matches(node); getOperation(const Node*) relies on this.
// NOTE(review): behavior when nothing matches is not visible here (presumably
// an error is raised) -- confirm.
const Operator& getOperatorFor(const Node* node);
// Returns the Operation that executes `node`.
// getOperatorFor() guarantees that the Operator it returns satisfies
// matches(node), so asking that Operator for an Operation specialized to
// `node` is always valid.
inline Operation getOperation(const Node* node) {
  const Operator& matched = getOperatorFor(node);
  return matched.getOperation(node);
}
// Returns operator Symbols considered similar to `input_op` (presumably used
// for "did you mean" suggestions in error messages -- TODO confirm).
TORCH_API std::vector<Symbol> findSimilarOperators(Symbol input_op);
// Adds `op` to the JIT operator registry, taking it by rvalue reference.
// Presumably the registry keeps it alive for the rest of the program, as
// required by the static-lifetime note on Operator -- confirm in definition.
TORCH_API void registerOperator(Operator&& op);
// XXX: this function is meant to be used with string literals only!
// Returns the registered Operator whose schema is `signature_literal`.
// NOTE(review): the literal-only restriction suggests the pointer itself is
// retained or used as a cache key -- confirm before passing other strings.
Operator& sig(const char* signature_literal);
// A fixed set of Operators, given as schema strings, indexed by operator
// Symbol so a Node can be checked against the set.
struct OperatorSet {
// `sig_literals` are schema strings; presumably they must be string literals,
// as with sig() -- confirm in the definition.
OperatorSet(std::initializer_list<const char*> sig_literals);
// XXX: Returns a nullptr if no Operator in the set matches n
Operator* find(const Node* n) const;
private:
// Operators in the set, grouped by their Symbol.
std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
} // namespace jit
} // namespace torch