| #pragma once |
| |
| #include <torch/csrc/jit/ir.h> |
| |
| namespace torch { |
| namespace jit { |
| |
| struct IfView { |
| explicit IfView(Node* node) : node_(node) { |
| AT_ASSERT(node->kind() == ::c10::prim::If); |
| } |
| Value* cond() const { |
| return node_->input(0); |
| } |
| Block* thenBlock() const { |
| return node_->blocks().at(0); |
| } |
| Block* elseBlock() const { |
| return node_->blocks().at(1); |
| } |
| ArrayRef<Value*> thenOutputs() const { |
| return thenBlock()->outputs(); |
| } |
| ArrayRef<Value*> elseOutputs() const { |
| return elseBlock()->outputs(); |
| } |
| ArrayRef<Value*> outputs() const { |
| return node_->outputs(); |
| } |
| Node* node() const { |
| return node_; |
| } |
| operator Node*() const { |
| return node_; |
| } |
| |
| void permuteOutputs(const std::vector<size_t>& new_output_order) { |
| node_->permuteOutputs(new_output_order); |
| thenBlock()->permuteOutputs(new_output_order); |
| elseBlock()->permuteOutputs(new_output_order); |
| } |
| |
| private: |
| Node* node_; |
| }; |
| |
| struct LoopView { |
| explicit LoopView(Node* node) : node_(node) { |
| AT_ASSERT( |
| node->kind() == ::c10::prim::Loop || node->kind() == ::c10::onnx::Loop); |
| } |
| Block* bodyBlock() const { |
| return node_->blocks().at(0); |
| } |
| Value* cond() const { |
| return node_->input(0); |
| } |
| Value* maxTripCount() const { |
| return node_->input(0); |
| } |
| Value* inputCond() const { |
| return node_->input(1); |
| } |
| Value* nextCond() const { |
| return bodyBlock()->outputs().at(0); |
| } |
| Value* currentTripCount() const { |
| return bodyBlock()->inputs().at(0); |
| } |
| ArrayRef<Value*> carriedInputs() const { |
| // skip trip count and cond |
| return node_->inputs().slice(2); |
| } |
| ArrayRef<Value*> carriedInputsWithCond() const { |
| // skip trip count and cond |
| return node_->inputs().slice(1); |
| } |
| ArrayRef<Value*> carriedOutputs() const { |
| return node_->outputs(); |
| } |
| ArrayRef<Value*> bodyCarriedInputs() const { |
| // skip trip count and cond |
| return bodyBlock()->inputs().slice(1); |
| } |
| ArrayRef<Value*> bodyCarriedOutputs() const { |
| return bodyBlock()->outputs().slice(1); |
| } |
| Node* node() const { |
| return node_; |
| } |
| operator Node*() const { |
| return node_; |
| } |
| |
| void permuteLoopCarried(const std::vector<size_t>& new_output_order) { |
| node_->permuteOutputs(new_output_order); |
| // skip trip count and cond |
| node_->permuteInputs(adjustIndices(2, new_output_order)); |
| auto adjusted_block_order = adjustIndices(1, new_output_order); |
| bodyBlock()->permuteOutputs(adjusted_block_order); |
| bodyBlock()->permuteInputs(adjusted_block_order); |
| } |
| |
| void replaceMaxTripCount(Value* new_max_trip_count) { |
| node_->replaceInput(0, new_max_trip_count); |
| } |
| void replaceInputCondition(Value* new_input_condition) { |
| node_->replaceInput(1, new_input_condition); |
| } |
| |
| // our way of encoding loops makes them difficult to turn back into python |
| // syntax. we have to check properties of the condition and trip count inputs |
| // to figure out which one it initially was. ModifiedLoops are not directly |
| // mappable to either For or While |
| enum LoopType { While, For, ModifiedLoop }; |
| |
| LoopType loopType() { |
| auto trip_count = toIValue(maxTripCount()); |
| auto cond_input = toIValue(inputCond()); |
| auto cond_next = toIValue(nextCond()); |
| |
| bool condition_is_always_true = |
| cond_input && cond_input->toBool() && cond_next && cond_next->toBool(); |
| bool trip_count_is_specified = !trip_count || // trip is not a constant |
| trip_count->toInt() != |
| std::numeric_limits<int64_t>::max() || // it is a constant but not |
| // the default one |
| currentTripCount()->uses().size() > |
| 0; // it is actually being used in the body. |
| |
| if (condition_is_always_true) { |
| // if the trip count was not specified this was a user-written while True: |
| return trip_count_is_specified ? For : While; |
| } else { |
| if (trip_count_is_specified) { |
| return ModifiedLoop; |
| } |
| return While; |
| } |
| } |
| |
| private: |
| Node* node_; |
| |
| // adjust index_ordering by adding indices 0 - thorugh adjust, and |
| // incrementing all existing inputs by adjust |
| static std::vector<size_t> adjustIndices( |
| size_t adjust, |
| const std::vector<size_t>& index_ordering) { |
| std::vector<size_t> adjusted; |
| adjusted.reserve(adjust + index_ordering.size()); |
| for (size_t i = 0; i < adjust; ++i) { |
| adjusted.push_back(i); |
| } |
| for (auto index : index_ordering) { |
| adjusted.push_back(index + adjust); |
| } |
| return adjusted; |
| } |
| }; |
| } // namespace jit |
| } // namespace torch |