#include <c10/util/irange.h>
#include <fusion.h>
#include <ir_all_nodes.h>
#include <ir_builder.h>
#include <mutator.h>

#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

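// Generic dispatch entry points. These forward to the IR's dispatch
// mechanism, which routes each node to the mutate overload for its concrete
// type.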
void OptOutMutator::mutate(Statement* s) {
  Statement::mutatorDispatch(this, s);
}

void OptOutMutator::mutate(Expr* e) {
  Expr::mutatorDispatch(this, e);
}

void OptOutMutator::mutate(Val* v) {
  Val::mutatorDispatch(this, v);
}

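// Record that all uses of `val` should be replaced with `mutation`. Mutations
// may not change the data type, and may only change the value type between
// Scalar and NamedScalar.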
void OptOutMutator::registerMutation(Val* val, Val* mutation) {
  bool val_is_ns = val->vtype() == ValType::NamedScalar;
  bool mutation_is_ns = mutation->vtype() == ValType::NamedScalar;
  bool val_is_scalar = val->vtype() == ValType::Scalar;
  bool mutation_is_scalar = mutation->vtype() == ValType::Scalar;
  TORCH_INTERNAL_ASSERT(
      mutation->dtype() == val->dtype() &&
          (mutation->vtype() == val->vtype() ||
           ((val_is_ns && mutation_is_scalar) ||
            (mutation_is_ns && val_is_scalar))),
      "Mutations are not allowed to change types, tried to go from: (",
      val->vtype(),
      ", ",
      val->dtype(),
      ") to: (",
      mutation->vtype(),
      ", ",
      mutation->dtype(),
      ")");
  mutations[val] = mutation;
}

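// Scalar leaves carry no other Vals to update, so there is nothing to do
// here; replacements for them are registered directly via registerMutation.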
void OptOutMutator::mutate(Bool* b) {}

void OptOutMutator::mutate(Double* d) {}

void OptOutMutator::mutate(Int* i) {}

void OptOutMutator::mutate(ComplexDouble* c) {}

void OptOutMutator::mutate(NamedScalar* ns) {}

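// Rebuild an IterDomain only if its start, extent, expanded extent, or stop
// offset has a registered mutation; otherwise leave it untouched.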
void OptOutMutator::mutate(IterDomain* id) {
  Val* start = maybeMutated(id->start());
  Val* extent = maybeMutated(id->extent());
  Val* expanded_extent = nullptr;
  if (id->hasExpandedExtent()) {
    expanded_extent = maybeMutated(id->expandedExtent());
  }
  Val* stop_offset = maybeMutated(id->stopOffset());
  if (start->sameAs(id->start()) && extent->sameAs(id->extent()) &&
      (!id->hasExpandedExtent() ||
       expanded_extent->sameAs(id->expandedExtent())) &&
      stop_offset->sameAs(id->stopOffset())) {
    return;
  }
  registerMutation(
      id,
      IterDomainBuilder(id)
          .start(start)
          .extent(extent)
          .stop_offset(stop_offset)
          .expanded_extent(expanded_extent)
          .build());
}

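// Rebuild a TensorDomain if any IterDomain in its root, rfactor, or leaf
// domain has been mutated.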
void OptOutMutator::mutate(TensorDomain* td) {
  bool mutated = false;

  auto updateIdVec = [&](const std::vector<IterDomain*>& ids) {
    std::vector<IterDomain*> updated_ids;
    for (auto id : ids) {
      auto updated_id = maybeMutated(id)->as<IterDomain>();
      updated_ids.push_back(updated_id);
      if (!updated_id->sameAs(id)) {
        mutated = true;
      }
    }
    return updated_ids;
  };

  std::vector<IterDomain*> root_dom = updateIdVec(td->getRootDomain());
  std::vector<IterDomain*> rfactor_dom = td->hasRFactor()
      ? updateIdVec(td->getMaybeRFactorDomain())
      : std::vector<IterDomain*>();
  std::vector<IterDomain*> domain = updateIdVec(td->domain());

  if (!mutated) {
    return;
  }

  Val* mutated_val = IrBuilder::create<TensorDomain>(
      td->container(), root_dom, rfactor_dom, domain, td->contiguity());
  registerMutation(td, mutated_val);
}

void OptOutMutator::mutate(TensorView* tv) {
  TensorDomain* td = maybeMutated(tv->domain())->as<TensorDomain>();
  if (!tv->domain()->sameAs(td)) {
    tv->setDomain(td);
  }
  // Don't register tv mutations as we just want to update the TD
}

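// Kernel IR nodes are produced during lowering and are not handled by this
// fusion-level mutator.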
void OptOutMutator::mutate(kir::Predicate*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}

void OptOutMutator::mutate(kir::TensorIndex*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}

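// The Expr mutators below all follow the same pattern: gather the
// possibly-mutated inputs and outputs, return early when nothing changed, and
// otherwise remove the stale expression from its container and recreate it
// with the updated Vals so that the rebuilt Expr defines the mutated outputs.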
void OptOutMutator::mutate(FullOp* fop) {
  Val* out = maybeMutated(fop->output(0));
  Val* fill_value = maybeMutated(fop->getFillValue());

  if (out->sameAs(fop->output(0)) &&
      fill_value->sameAs(fop->getFillValue())) {
    return;
  }
  auto container = fop->container();
  container->removeExpr(fop);
  IrBuilder::create<FullOp>(container, out, fill_value, fop->dtype());
}

void OptOutMutator::mutate(ARangeOp* aop) {
  Val* out = maybeMutated(aop->output(0));

  if (out->sameAs(aop->output(0))) {
    return;
  }
  auto container = aop->container();
  container->removeExpr(aop);
  IrBuilder::create<ARangeOp>(
      container,
      out,
      aop->start(),
      aop->end(),
      aop->step(),
      aop->dtype(),
      aop->getLinearLogicalIndex());
}

void OptOutMutator::mutate(EyeOp* eop) {
  Val* out = maybeMutated(eop->output(0));

  if (out->sameAs(eop->output(0))) {
    return;
  }
  auto container = eop->container();
  container->removeExpr(eop);
  IrBuilder::create<EyeOp>(
      container, out, eop->dtype(), eop->getIndex1(), eop->getIndex2());
}

void OptOutMutator::mutate(UnaryOp* uop) {
  Val* out = maybeMutated(uop->out());
  Val* in = maybeMutated(uop->in());

  if (out->sameAs(uop->out()) && in->sameAs(uop->in())) {
    return;
  }
  auto container = uop->container();
  auto uop_type = uop->getUnaryOpType();
  container->removeExpr(uop);
  IrBuilder::create<UnaryOp>(container, uop_type, out, in);
}

void OptOutMutator::mutate(BinaryOp* bop) {
  Val* out = maybeMutated(bop->out());
  Val* lhs = maybeMutated(bop->lhs());
  Val* rhs = maybeMutated(bop->rhs());

  if (out == bop->out() && lhs == bop->lhs() && rhs == bop->rhs()) {
    return;
  }

  auto container = bop->container();
  auto bop_type = bop->getBinaryOpType();
  container->removeExpr(bop);
  IrBuilder::create<BinaryOp>(container, bop_type, out, lhs, rhs);
}

void OptOutMutator::mutate(TernaryOp* top) {
  Val* out = maybeMutated(top->out());
  Val* in1 = maybeMutated(top->in1());
  Val* in2 = maybeMutated(top->in2());
  Val* in3 = maybeMutated(top->in3());

  if (out == top->out() && in1 == top->in1() && in2 == top->in2() &&
      in3 == top->in3()) {
    return;
  }

  auto container = top->container();
  auto top_type = top->getTernaryOpType();
  container->removeExpr(top);
  IrBuilder::create<TernaryOp>(container, top_type, out, in1, in2, in3);
}

void OptOutMutator::mutate(RNGOp* rop) {
  Val* out = maybeMutated(rop->output(0));
  auto& parameters = rop->getParameters();
  std::vector<Val*> mutated_parameters;
  for (auto v : parameters) {
    mutated_parameters.emplace_back(maybeMutated(v));
  }

  if (out == rop->output(0) && mutated_parameters == parameters) {
    return;
  }

  auto container = rop->container();
  auto rop_type = rop->getRNGOpType();
  container->removeExpr(rop);
  IrBuilder::create<RNGOp>(
      container,
      rop_type,
      out,
      rop->dtype(),
      mutated_parameters,
      rop->getRNGOffset(),
      rop->getPhiloxIndex());
}

void OptOutMutator::mutate(ReductionOp* rop) {
  Val* out = maybeMutated(rop->out());
  Val* in = maybeMutated(rop->in());
  Val* init = maybeMutated(rop->init());
  if (out->sameAs(rop->out()) && in->sameAs(rop->in()) &&
      init->sameAs(rop->init())) {
    return;
  }

  auto container = rop->container();
  auto rop_type = rop->getReductionOpType();
  container->removeExpr(rop);
  IrBuilder::create<ReductionOp>(
      container, rop_type, init, out, in, rop->isAllreduce());
}

void OptOutMutator::mutate(GroupedReductionOp* rop) {
  bool is_same = true;

  std::vector<Val*> outputs;
  for (auto out : rop->outputs()) {
    auto maybe_mutated = maybeMutated(out);
    is_same = is_same && maybe_mutated->sameAs(out);
    outputs.push_back(maybe_mutated);
  }

  std::vector<Val*> inputs;
  for (auto in : rop->inputs()) {
    auto maybe_mutated = maybeMutated(in);
    is_same = is_same && maybe_mutated->sameAs(in);
    inputs.push_back(maybe_mutated);
  }

  std::vector<Val*> init_vals;
  for (auto init : rop->initVals()) {
    auto maybe_mutated = maybeMutated(init);
    is_same = is_same && maybe_mutated->sameAs(init);
    init_vals.push_back(maybe_mutated);
  }

  if (is_same) {
    return;
  }

  auto container = rop->container();
  const auto& rop_types = rop->getReductionOpTypes();
  container->removeExpr(rop);
  IrBuilder::create<GroupedReductionOp>(
      container, rop_types, init_vals, outputs, inputs, rop->isAllreduce());
}

namespace {
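// Returns true if both values are null or both are non-null and equivalent.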
inline bool compareOptional(Val* a, Val* b) {
  if (!a || !b) {
    return (!a && !b);
  }
  return a->sameAs(b);
}

} // namespace

void OptOutMutator::mutate(WelfordOp* wop) {
  Val* out_avg = maybeMutated(wop->outAvg());
  Val* out_var = maybeMutated(wop->outVar());
  Val* out_N = maybeMutated(wop->outN());

  Val* in_avg = maybeMutated(wop->inAvg());
  Val* in_var = wop->inVar() ? maybeMutated(wop->inVar()) : nullptr;
  Val* in_N = maybeMutated(wop->inN());

  Val* init_avg = wop->initAvg() ? maybeMutated(wop->initAvg()) : nullptr;
  Val* init_var = wop->initVar() ? maybeMutated(wop->initVar()) : nullptr;
  Val* init_N = maybeMutated(wop->initN());

  const bool out_compare = out_avg->sameAs(wop->outAvg()) &&
      out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN());
  const bool in_compare = in_avg->sameAs(wop->inAvg()) &&
      compareOptional(in_var, wop->inVar()) && in_N->sameAs(wop->inN());
  const bool init_compare = compareOptional(init_avg, wop->initAvg()) &&
      compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN());

  if (out_compare && init_compare && in_compare) {
    return;
  }

  auto container = wop->container();
  container->removeExpr(wop);
  IrBuilder::create<WelfordOp>(
      container,
      out_avg,
      out_var,
      out_N,
      in_avg,
      in_var,
      in_N,
      init_avg,
      init_var,
      init_N,
      wop->isAllreduce());
}

void OptOutMutator::mutate(GroupedWelfordOp* wop) {
  bool is_same = true;

  std::vector<WelfordTriplet> output_vals;
  for (const auto& out : wop->outputVals()) {
    auto maybe_mutated =
        out.transform([&](Val* val) { return maybeMutated(val); });
    is_same = is_same && maybe_mutated.sameAs(out);
    output_vals.push_back(maybe_mutated);
  }

  std::vector<WelfordTriplet> input_vals;
  for (const auto& inp : wop->inputVals()) {
    auto maybe_mutated =
        inp.transform([&](Val* val) { return maybeMutated(val); });
    is_same = is_same && maybe_mutated.sameAs(inp);
    input_vals.push_back(maybe_mutated);
  }

  std::vector<WelfordTriplet> init_vals;
  for (const auto& init : wop->initVals()) {
    auto maybe_mutated =
        init.transform([&](Val* val) { return maybeMutated(val); });
    is_same = is_same && maybe_mutated.sameAs(init);
    init_vals.push_back(maybe_mutated);
  }

  if (is_same) {
    return;
  }

  auto container = wop->container();
  container->removeExpr(wop);
  IrBuilder::create<GroupedWelfordOp>(
      container, output_vals, input_vals, init_vals, wop->isAllreduce());
}

void OptOutMutator::mutate(MmaOp* mma) {
  Val* out = maybeMutated(mma->out());
  Val* in_a = maybeMutated(mma->inA());
  Val* in_b = maybeMutated(mma->inB());
  Val* init = maybeMutated(mma->init());

  if (out->sameAs(mma->out()) && in_a->sameAs(mma->inA()) &&
      in_b->sameAs(mma->inB())) {
    return;
  }

  auto container = mma->container();
  auto options = mma->options();
  container->removeExpr(mma);
  C10_UNUSED auto new_mma =
      IrBuilder::create<MmaOp>(container, out, in_a, in_b, init, options);
}

void OptOutMutator::mutate(LoadStoreOp* ldst) {
  Val* out = maybeMutated(ldst->out());
  Val* in = maybeMutated(ldst->in());
  auto op_type = ldst->opType();

  if (out->sameAs(ldst->out()) && in->sameAs(ldst->in())) {
    return;
  }

  auto container = ldst->container();
  container->removeExpr(ldst);
  IrBuilder::create<LoadStoreOp>(container, op_type, out, in);
}

void OptOutMutator::mutate(BroadcastOp* bop) {
  Val* out = maybeMutated(bop->out());
  Val* in = maybeMutated(bop->in());

  if (out->sameAs(bop->out()) && in->sameAs(bop->in())) {
    return;
  }

  auto container = bop->container();
  auto flags = bop->getBroadcastDimFlags();
  container->removeExpr(bop);
  IrBuilder::create<BroadcastOp>(container, out, in, flags);
}

void OptOutMutator::mutate(TransposeOp* top) {
  TensorView* out = maybeMutated(top->out())->as<TensorView>();
  TensorView* in = maybeMutated(top->in())->as<TensorView>();

  if (out->sameAs(top->out()) && in->sameAs(top->in())) {
    return;
  }

  auto container = top->container();
  auto new2old = top->new2old();
  container->removeExpr(top);
  IrBuilder::create<TransposeOp>(container, out, in, new2old);
}

void OptOutMutator::mutate(ExpandOp* eop) {
  bool is_same = true;

  TensorView* out = maybeMutated(eop->out())->as<TensorView>();
  is_same = is_same && out->sameAs(eop->out());
  TensorView* in = maybeMutated(eop->in())->as<TensorView>();
  is_same = is_same && in->sameAs(eop->in());

  std::vector<Val*> expanded_extents;
  expanded_extents.reserve(eop->expanded_extents().size());
  for (auto expanded_extent : eop->expanded_extents()) {
    expanded_extents.push_back(maybeMutated(expanded_extent));
    if (!expanded_extents.back()->sameAs(expanded_extent)) {
      is_same = false;
    }
  }

  if (is_same) {
    return;
  }

  auto container = eop->container();
  container->removeExpr(eop);
  IrBuilder::create<ExpandOp>(container, out, in, expanded_extents);
}

void OptOutMutator::mutate(ShiftOp* sop) {
  Val* out = maybeMutated(sop->out())->asVal();
  Val* in = maybeMutated(sop->in())->asVal();

  if (out->sameAs(sop->out()) && in->sameAs(sop->in())) {
    return;
  }

  auto offsets = sop->offsets();
  auto pad_width = sop->padWidth();
  auto container = sop->container();
  container->removeExpr(sop);
  IrBuilder::create<ShiftOp>(container, out, in, offsets, pad_width);
}

void OptOutMutator::mutate(GatherOp* op) {
  Val* out = maybeMutated(op->out())->asVal();
  Val* in = maybeMutated(op->in())->asVal();

  if (out->sameAs(op->out()) && in->sameAs(op->in())) {
    return;
  }

  auto window_shape = op->windowShape();
  auto pad_width = op->padWidth();
  auto container = op->container();
  container->removeExpr(op);
  IrBuilder::create<GatherOp>(container, out, in, window_shape, pad_width);
}

void OptOutMutator::mutate(ViewAsScalar* vop) {
  TensorView* out = maybeMutated(vop->out())->as<TensorView>();
  TensorView* in = maybeMutated(vop->in())->as<TensorView>();

  if (out->sameAs(vop->out()) && in->sameAs(vop->in())) {
    return;
  }

  auto container = vop->container();
  container->removeExpr(vop);
  IrBuilder::create<ViewAsScalar>(
      container, out, in, vop->vector_id(), vop->index());
}

void OptOutMutator::mutate(ViewOp* vop) {
  TensorView* out = maybeMutated(vop->out())->as<TensorView>();
  TensorView* in = maybeMutated(vop->in())->as<TensorView>();

  if (out->sameAs(vop->out()) && in->sameAs(vop->in())) {
    return;
  }

  auto container = vop->container();
  container->removeExpr(vop);
  IrBuilder::create<ViewOp>(container, out, in);
}

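// Scheduling transformations on IterDomains (Split, Merge, Swizzle2D) are
// rebuilt the same way as the tensor expressions above.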
void OptOutMutator::mutate(Split* s) {
  IterDomain* ot = maybeMutated(s->outer())->as<IterDomain>();
  IterDomain* inr = maybeMutated(s->inner())->as<IterDomain>();
  IterDomain* in = maybeMutated(s->in())->as<IterDomain>();
  Val* fact = maybeMutated(s->factor())->as<Val>();
  Val* start_offset = maybeMutated(s->startOffset());
  Val* stop_offset = maybeMutated(s->stopOffset());

  if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) &&
      in->sameAs(s->in()) && areEqualScalars(fact, s->factor()) &&
      start_offset->sameAs(s->startOffset()) &&
      stop_offset->sameAs(s->stopOffset())) {
    return;
  }

  auto container = s->container();
  auto inner_split = s->innerSplit();
  container->removeExpr(s);
  C10_UNUSED auto new_node = IrBuilder::create<Split>(
      container, ot, inr, in, fact, inner_split, start_offset, stop_offset);
}

void OptOutMutator::mutate(Merge* m) {
  IterDomain* ot = maybeMutated(m->out())->as<IterDomain>();
  IterDomain* otr = maybeMutated(m->outer())->as<IterDomain>();
  IterDomain* in = maybeMutated(m->inner())->as<IterDomain>();

  if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) &&
      in->sameAs(m->inner())) {
    return;
  }

  auto container = m->container();
  container->removeExpr(m);
  C10_UNUSED auto new_node = IrBuilder::create<Merge>(container, ot, otr, in);
}

void OptOutMutator::mutate(Swizzle2D* m) {
  IterDomain* outx = maybeMutated(m->outX())->as<IterDomain>();
  IterDomain* outy = maybeMutated(m->outY())->as<IterDomain>();

  IterDomain* inx = maybeMutated(m->inX())->as<IterDomain>();
  IterDomain* iny = maybeMutated(m->inY())->as<IterDomain>();

  auto swizzle_type = m->swizzleType();

  if (outx->sameAs(m->outX()) && outy->sameAs(m->outY()) &&
      inx->sameAs(m->inX()) && iny->sameAs(m->inY())) {
    return;
  }
  auto container = m->container();
  container->removeExpr(m);
  C10_UNUSED auto new_node = IrBuilder::create<Swizzle2D>(
      container, outx, outy, inx, iny, swizzle_type);
}

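// Mutation of the remaining kernel IR (kir::) nodes has not been implemented.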
void OptOutMutator::mutate(kir::Allocate*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::BlockSync*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridSync*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::CpAsyncWait*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::CpAsyncCommit*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::InitMagicZero*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::UpdateMagicZero*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::ForLoop*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::IfThenElse*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridReduction*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GroupedGridReduction*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridBroadcast*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GridWelford*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::GroupedGridWelford*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::AllocateFusedReduction*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::Swizzle2DInt*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::PairSelect*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::IntPair*) {
  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}

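// Thin wrapper around IrContainer::removeExpr so mutator subclasses can drop
// expressions from a container.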
void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) {
  container->removeExpr(expr);
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch