#include <torch/csrc/jit/symbolic_script.h>

namespace torch {
namespace jit {
namespace {
std::mutex lock;
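// TorchScript source for symbolic gradient definitions. Each top-level def
// mirrors an ATen op: it returns the forward result together with a
// `backward` closure, and that closure maps grad_output to one gradient per
// input (None where an input such as output_size is not differentiable).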
const std::vector<std::string> functions = {
R"(
def mul(self, other):
def backward(grad_output):
grad_self = (grad_output * other).sum_to_size(self.size())
grad_other = (grad_output * self).sum_to_size(other.size())
return grad_self, grad_other
return self * other, backward
def adaptive_avg_pool2d(self,
output_size: List[int]):
def backward(grad_output):
grad_self = torch.adaptive_avg_pool2d_backward(grad_output, self)
return grad_self, None
return torch.adaptive_avg_pool2d(self, output_size), backward
)"};
std::unordered_map<std::string, GradientPair> schema_to_graphs;
// This map is a workaround to cache compiled gradient_pairs. Ideally this graph
// should be compiled only once and saved in Operator structure.
// This should be done along with merging into native_functions.yaml.
std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
} // anonymous namespace
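// Unpack a (prim::Function, context) tuple produced by a backward closure
// into the function's subgraph and the captured context value.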
std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
  AT_CHECK(
      closure->node()->kind() == prim::TupleConstruct,
      "closure must be a literal tuple construct");
  Value* fn = closure->node()->inputs().at(0);
  Value* context = closure->node()->inputs().at(1);

  AT_CHECK(
      fn->node()->kind() == prim::Function,
      "closure tuple must contain a prim::Function");
  return std::make_pair(fn->node()->g(attr::Subgraph), context);
}
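// Recover the original op's return type from the rewritten forward tuple:
// the trailing element is the context, so a two-element tuple means a single
// original return value; otherwise the remaining elements form a tuple type.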
Argument originalReturnType(const TupleTypePtr& tup) {
  AT_CHECK(tup->elements().size() > 1);
  if (tup->elements().size() == 2)
    return Argument("", tup->elements().at(0));
  std::vector<TypePtr> types = tup->elements().vec();
  types.pop_back();
  return Argument("", TupleType::create(std::move(types)));
}
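// Compile each gradient definition in the module into a GradientPair and
// index it by the canonical schema string of the op it differentiates.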
void loadModule(const std::shared_ptr<script::Module>& module) {
  for (const auto& method_ : module->get_methods()) {
    const auto& method = method_.value();
    GradientPair pair;
    pair.forward = method->graph();

    // lookup the backward function
    Node* forward_tuple = pair.forward->outputs().at(0)->node();
    if (forward_tuple->kind() != prim::TupleConstruct) {
      throw script::ErrorReport(forward_tuple->getSourceLocation())
          << "gradient must return a literal tuple";
    }

    Value* context;
    std::tie(pair.backward, context) =
        extractClosure(forward_tuple->inputs().back());

    // do surgery on the forward function to remove the closure tuple and
    // replace it with the context variable:
    //   backward = (<lambda>, context_tuple)
    //   return original, backward
    //   -----
    //   return original, context_tuple
    std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
    new_inputs.back() = context;
    Value* new_tuple =
        pair.forward->appendNode(pair.forward->createTuple(new_inputs))
            ->output();
    pair.forward->eraseOutput(0);
    pair.forward->registerOutput(new_tuple);
    forward_tuple->destroy();

    // derive schema from original function's schema:
    const FunctionSchema& loaded_schema = method->getSchema();
    FunctionSchema actual_schema(
        Symbol::aten(loaded_schema.name()),
        loaded_schema.arguments(),
        {originalReturnType(new_tuple->type()->expect<TupleType>())});
    std::string key = canonicalSchemaString(actual_schema);
    schema_to_graphs[key] = std::move(pair);
  }
}
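// Compile every entry of `functions` into a script::Module and register its
// gradient pairs. Called lazily, under `lock`, on the first schema lookup.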
void loadFunctions() {
  for (const std::string& str : functions) {
    auto cu = std::make_shared<script::Module>();
    script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
    loadModule(cu);
  }
}
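// Look up the symbolic gradient for `schema`, first in the per-schema cache
// and then by canonical schema string; returns c10::nullopt when no symbolic
// script is defined for the op.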
c10::optional<GradientPair> gradientInfoForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (schema_to_graphs.size() == 0) {
    loadFunctions();
  }
  auto cache_it = cached_gradient_pairs.find(&schema);
  if (cache_it != cached_gradient_pairs.end()) {
    return cache_it->second;
  } else {
    auto schema_str = canonicalSchemaString(schema);
    auto sym_script_it = schema_to_graphs.find(schema_str);
    if (sym_script_it != schema_to_graphs.end()) {
      cached_gradient_pairs.emplace_hint(
          cache_it, &schema, sym_script_it->second);
      return sym_script_it->second;
    }
  }
  return c10::nullopt;
}
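// True if a symbolic gradient is registered for `schema`.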
bool hasGradientInfoForSchema(const FunctionSchema& schema) {
  return gradientInfoForSchema(schema).has_value();
}
} // namespace jit
} // namespace torch