blob: 7a325601d70c54b01cc39f340626996ad7c2ff9b [file] [log] [blame]
#include <ir_builder.h>
#include <ir_utils.h>
#include <root_domain_map.h>
#include <transform_iter.h>
#include <grouped_reduction.h>

#include <unordered_set>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {
// Return if ref and other are transformed in the same way.
// Return true if ref and other are transformed in the same way.
//
// The root domains of ref and other are first mapped position by
// position (ref root i -> other root i). BestEffortReplay is then run
// between the two leaf domains with that root mapping, and every leaf
// axis of ref must map to the leaf axis of other at the same position.
//
// Tensors whose root or leaf ranks differ cannot be transformed in the
// same way. False is returned for them instead of indexing past the end
// of the shorter domain.
bool hasMatchingTransformations(TensorView* ref, TensorView* other) {
  const auto& ref_root = ref->getRootDomain();
  const auto& other_root = other->getRootDomain();
  if (ref_root.size() != other_root.size() ||
      ref->nDims() != other->nDims()) {
    return false;
  }

  // Position-wise root mapping from ref to other.
  std::unordered_map<IterDomain*, IterDomain*> ref_2_other;
  for (const auto i : c10::irange(ref_root.size())) {
    ref_2_other.emplace(ref_root.at(i), other_root.at(i));
  }

  // Keep this as a copy (not a reference): the BestEffortReplay object is
  // a temporary that is destroyed at the end of the statement.
  // The map is looked up with ref IterDomains and yields other IterDomains.
  const auto replay =
      BestEffortReplay(
          other->domain()->domain(), ref->domain()->domain(), ref_2_other)
          .getReplay();

  for (const auto i : c10::irange(ref->nDims())) {
    const auto it = replay.find(ref->axis(i));
    if (it == replay.end() || it->second != other->axis(i)) {
      return false;
    }
  }
  return true;
}
// Validate grouping of reductions and return a new max producer position
// Validate that the reductions given as parallel lists (inputs[i] is the
// input and outputs[i] the output of the i-th reduction) can be grouped
// into a single GroupedReductionOp. Nothing is returned; an internal
// assertion fails if grouping is invalid.
//
// The first output is used as the reference. Every other output must
// match it in:
//   - the number of root dimensions and of leaf dimensions,
//   - the broadcast and reduction pattern of each root IterDomain,
//   - being exact-mapped (or sameAs) at every non-broadcast root position,
//   - the applied transformations (see hasMatchingTransformations),
//   - the computeAt position.
// In addition, no input may depend on any of the outputs.
void validateReductionGrouping(
    const std::vector<Val*>& inputs,
    const std::vector<Val*>& outputs) {
  TORCH_INTERNAL_ASSERT(inputs.size() == outputs.size());
  TORCH_INTERNAL_ASSERT(!inputs.empty());

  auto fusion = dynamic_cast<Fusion*>(outputs[0]->container());
  TORCH_INTERNAL_ASSERT(
      fusion != nullptr, "Grouping of reductions must be done within a Fusion");

  ExactRootDomainMap exact_map(fusion);

  // Pick the first output TV as a reference and compare it with the
  // rest. Do not allow grouping if any mismatch is detected.
  auto ref_tv = outputs[0]->as<TensorView>();
  // Bind by reference; no need to copy the reference root domain.
  const auto& ref_domain = ref_tv->getRootDomain();
  const auto num_root_dims = ref_domain.size();
  const auto num_dims = ref_tv->nDims();
  const auto ref_ca_pos = ref_tv->getComputeAtPosition();

  for (auto output : outputs) {
    auto output_tv = output->as<TensorView>();
    if (output_tv == ref_tv) {
      continue;
    }

    const auto& output_domain = output_tv->getRootDomain();

    TORCH_INTERNAL_ASSERT(
        output_domain.size() == num_root_dims,
        "Invalid grouped reduction due to mismatched number of root dimensions. "
        "Expected: ",
        num_root_dims,
        ". Detected: ",
        output_domain.size(),
        ". Invalid output tensor: ",
        output_tv->toString());

    TORCH_INTERNAL_ASSERT(
        output_tv->nDims() == num_dims,
        "Invalid grouped reduction due to mismatched number of dimensions. "
        "Expected: ",
        num_dims,
        ". Detected: ",
        output_tv->nDims(),
        ". Invalid output tensor: ",
        output_tv->toString());

    // Use a distinct index name so it does not shadow any outer variable.
    for (const auto root_idx : c10::irange(num_root_dims)) {
      auto ref_id = ref_domain.at(root_idx);
      auto output_id = output_domain.at(root_idx);

      // If an IterDomain is broadcast, require the other
      // corresponding IterDomains are also broadcast. This may not be
      // necessary but not completely certain.
      TORCH_INTERNAL_ASSERT(
          ref_id->isBroadcast() == output_id->isBroadcast(),
          "Invalid grouped reduction due to mismatched broadcast root domains. ",
          "Reference domain: ",
          ref_id->toString(),
          ". Mismatched domain: ",
          output_id->toString(),
          ". Invalid tensor: ",
          output_tv->toString());

      // Broadcast domains are not checked any further.
      if (ref_id->isBroadcast()) {
        continue;
      }

      TORCH_INTERNAL_ASSERT(
          ref_id->isReduction() == output_id->isReduction(),
          "Invalid grouped reduction due to mismatched reduction root domains. ",
          "Reference domain: ",
          ref_id->toString(),
          ". Mismatched domain: ",
          output_id->toString(),
          ". Invalid tensor: ",
          output_tv->toString());

      TORCH_INTERNAL_ASSERT(
          exact_map.areMapped(ref_id, output_id) || ref_id->sameAs(output_id),
          "Invalid grouped reduction due to mismatched root domains. ",
          "Reference domain: ",
          ref_id->toString(),
          ". Mismatched domain: ",
          output_id->toString(),
          ". Invalid tensor: ",
          output_tv->toString());
    }

    TORCH_INTERNAL_ASSERT(
        hasMatchingTransformations(ref_tv, output_tv),
        "Invalid grouped reduction due to mismatched transformations. ",
        "Reference tensor: ",
        ref_tv->toString(),
        ". Mismatched tensor: ",
        output_tv->toString());

    // Must have the same computeAt position
    TORCH_INTERNAL_ASSERT(
        output_tv->getComputeAtPosition() == ref_ca_pos,
        "Invalid grouped reduction due to mismatched computeAt position. ",
        "Reference tensor: ",
        ref_tv->toString(),
        ". Mismatched tensor: ",
        output_tv->toString());
  }

  // Must not have any data dependency from outputs to inputs. Any value
  // found between them is reported in the failure message.
  const auto all_dep_vals = DependencyCheck::getAllValsBetween(
      {outputs.begin(), outputs.end()}, inputs);
  if (!all_dep_vals.empty()) {
    std::stringstream ss;
    ss << "Invalid dependency:";
    for (auto val : all_dep_vals) {
      ss << " " << val->toString();
    }
    TORCH_INTERNAL_ASSERT(all_dep_vals.empty(), ss.str());
  }
}
} // namespace
// Group the ReductionOps that define reduction_outputs into a single
// GroupedReductionOp.
//
// Each tensor must be distinct and must be the output of a ReductionOp
// that is not an allreduce. The grouping is validated by
// validateReductionGrouping. A GroupedReductionOp is then created in the
// container of the first tensor, with the reduction op types, init values,
// outputs and inputs of the original ReductionOps (in the given order).
// Finally the max producer position of each output is updated.
//
// Raises (via TORCH_CHECK) if the list is empty, contains the same tensor
// more than once, or contains a tensor that is not defined by a
// ReductionOp.
void groupReductions(const std::vector<TensorView*>& reduction_outputs) {
  TORCH_CHECK(!reduction_outputs.empty(), "No tensor is given");

  auto container = reduction_outputs[0]->container();

  const auto num_reductions = reduction_outputs.size();

  std::vector<BinaryOpType> op_types(num_reductions);
  std::vector<Val*> init_vals(num_reductions);
  std::vector<Val*> outputs(num_reductions);
  std::vector<Val*> inputs(num_reductions);

  // validateReductionGrouping skips any output identical to its reference
  // tensor, so duplicates would otherwise pass validation. The grouped op
  // would then list the same output more than once.
  std::unordered_set<TensorView*> seen_outputs;

  for (const auto i : c10::irange(num_reductions)) {
    auto reduction_out = reduction_outputs.at(i);
    TORCH_CHECK(
        reduction_out->definition() != nullptr,
        "Invalid tensor to group: ",
        reduction_out->toString(),
        ". Definition not found");
    auto rop = dynamic_cast<ReductionOp*>(reduction_out->definition());
    TORCH_CHECK(
        rop != nullptr,
        "Invalid tensor to group: ",
        reduction_out->toString(),
        ". Not an output of a ReductionOp: ",
        reduction_out->definition()->toString());
    TORCH_CHECK(
        seen_outputs.insert(reduction_out).second,
        "Invalid tensor to group: ",
        reduction_out->toString(),
        ". The same tensor is given more than once");
    // Fused reduction is only enabled during the lowering, so at this
    // point it should be false.
    TORCH_INTERNAL_ASSERT(
        !rop->isAllreduce(), "Invalid ReductionOp: ", rop->toString());
    op_types.at(i) = rop->getReductionOpType();
    init_vals.at(i) = rop->init();
    outputs.at(i) = rop->out();
    inputs.at(i) = rop->in();
  }

  validateReductionGrouping(inputs, outputs);

  IrBuilder::create<GroupedReductionOp>(
      container, op_types, init_vals, outputs, inputs);

  for (auto output : ir_utils::filterByType<TensorView>(outputs)) {
    output->updateMaxProducerPosition();
  }
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch