#ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
#define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/create_scope_op.h"

namespace caffe2 {

template <class Context>
class ONNXWhileOp final : public Operator<Context> {
 public:
  explicit ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        parent_ws_(ws),
        has_trip_count_(
            this->template GetSingleArgument<int64_t>("has_trip_count", 0)),
        has_cond_(this->template GetSingleArgument<int64_t>("has_cond", 0)),
        save_scopes_(
            this->template GetSingleArgument<int64_t>("save_scopes", 0)),
        disable_scopes_(
            this->template GetSingleArgument<int64_t>("disable_scopes", 0)),
        num_loop_carried_deps_(this->template GetSingleArgument<int64_t>(
            "num_loop_carried_deps",
            -1)) {
    CAFFE_ENFORCE(
        this->template HasSingleArgumentOfType<NetDef>("body"),
        "body net must be specified in ONNXWhile operator");
    if (disable_scopes_) {
      CAFFE_ENFORCE(
          !save_scopes_, "Cannot save scopes when disable_scopes=True");
    }
    body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef());
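    // Give unnamed body nets a unique name. The static counter is shared by
    // all ONNXWhileOp instances in the process, so While ops sharing a
    // workspace do not pick up each other's body nets by name.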
    static int64_t counter = -1;
    if (!body_net_def_.has_name()) {
      ++counter;
      if (counter == 0) {
        body_net_def_.set_name("loop_net");
      } else {
        body_net_def_.set_name("loop_net." + c10::to_string(counter));
      }
    }
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int, bool, long>>::call(this, Input(1));
  }

  // Operator
  //  Inputs: max trip count, condition, initial loop-carried dependencies
  //  Outputs: final loop-carried dependencies, scan_outputs
  // Body
  //  Inputs: iteration number, condition, loop-carried dependencies
  //  Outputs: condition, loop-carried dependencies, scan_outputs
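  //
  // Illustrative layout (hypothetical names, with N = 2 loop-carried
  // dependencies and K = 1 scan output):
  //   Operator inputs : [max_trip_count, cond_init, lcd_0, lcd_1]
  //   Operator outputs: [lcd_0_final, lcd_1_final, scan_0]
  //   Body inputs     : [iteration, cond_in, lcd_0, lcd_1]
  //   Body outputs    : [cond_out, lcd_0, lcd_1, scan_0]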
  template <typename CondVarType>
  bool DoRunWithType() {
    // Clear workspaces from the previous invocations of the loop
    // and set up a local scope for the first iteration
    ws_stack_.clear();
    auto loop_ws = !disable_scopes_
        ? ws_stack_.pushForwardWorkspace(parent_ws_).get()
        : parent_ws_;

    constexpr int64_t num_inputs_before_lcds = 2;
    // First input is the maximum trip count. Second input is the condition
    // variable (for the first iteration). The rest of the inputs are
    // loop-carried dependencies.
    int64_t num_loop_carried_deps;
    if (num_loop_carried_deps_ != -1) {
      num_loop_carried_deps = num_loop_carried_deps_;
    } else {
      num_loop_carried_deps = InputSize() - num_inputs_before_lcds;
    }
    int64_t max_trip_count = *Input(0).template data<int64_t>();
    const bool first_iter_condition = *Input(1).template data<CondVarType>();

    scope_ = std::make_shared<LocalScope>(
        loop_ws, body_net_def_, num_loop_carried_deps);

    // Body graph has 1+N+K outputs: recalculated condition variable, N
    // loop-carried dependencies, and K scan_outputs
    int num_scan_outputs =
        scope_->net()->external_output().size() - num_loop_carried_deps - 1;

    CAFFE_ENFORCE_GE(
        num_scan_outputs,
        0,
        "Body graph must have 1+N+K outputs: the recalculated condition, "
        "N loop-carried dependencies, and K (possibly zero) scan "
        "outputs");

    // Copy initial loop-carried dependencies
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds));
    }

    // Initialize iteration variable
    scope_->set_iteration(0ll);

    // Initialize input condition variable
    scope_->template set_input_condition<CondVarType>(first_iter_condition);

    auto valid_iter_num = [this, max_trip_count](int64_t i) {
      if (has_trip_count_) {
        return i < max_trip_count;
      } else {
        return true;
      }
    };

    auto condition_true =
        [this, first_iter_condition](int64_t i, bool cond_value) {
          if (has_cond_) {
            if (i == 0) {
              return (bool)first_iter_condition;
            } else {
              return cond_value;
            }
          } else {
            return true;
          }
        };
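
    // Taken together, these lambdas give the ONNX Loop termination modes:
    //   has_trip_count=0, has_cond=0 -> loop never terminates here
    //   has_trip_count=1, has_cond=0 -> for-loop over max_trip_count
    //   has_trip_count=0, has_cond=1 -> while-loop on the condition
    //   has_trip_count=1, has_cond=1 -> while-loop bounded by max_trip_count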

    // Allocate scan_outputs for zero-iteration case
    for (int i = 0; i < num_scan_outputs; ++i) {
      Output(i + num_loop_carried_deps)->Resize(0);
      Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
    }

    // Use this to keep track of the sizes of the scan outputs and validate
    // they're the same across iterations.
    std::vector<std::vector<int64_t>> scan_outputs_sizes;

    Workspace* cur_ws = nullptr;
    bool cur_output_condition = false;

    while (true) {
      int64_t itr = scope_->iteration();
      if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) {
        if (!scope_->net()->Run()) {
          return false;
        }

        cur_ws = scope_->workspace();
        cur_output_condition = scope_->template output_condition<CondVarType>();
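        // With save_scopes, each iteration gets a fresh child workspace
        // pushed onto ws_stack_, so blobs produced by this iteration are
        // preserved rather than overwritten by the next one.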
        if (save_scopes_) {
          loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get();
          scope_ = std::make_shared<LocalScope>(
              loop_ws, body_net_def_, num_loop_carried_deps);
        }

        // Copy forward loop-carried dependencies
        for (int i = 0; i < num_loop_carried_deps; ++i) {
          Blob* b = cur_ws->GetBlob(
              scope_->net()->external_output()[i + 1]);
          const Tensor& t = b->template Get<Tensor>();
          scope_->lcd_tensor(i)->CopyFrom(t);
        }
        // Copy out scan_outputs
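        // Scan outputs are stacked along a new leading axis: a per-iteration
        // tensor of shape [d1, ..., dk] accumulates into an operator output
        // of shape [num_iterations, d1, ..., dk].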
        for (int i = 0; i < num_scan_outputs; ++i) {
          int net_output_idx = i + 1 + num_loop_carried_deps;
          const Tensor& scan_output =
              cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx])
                  ->template Get<Tensor>();
          auto* scan_output_target = Output(i + num_loop_carried_deps);
          if (itr == 0) {
            auto dims = scan_output.sizes().vec();
            scan_outputs_sizes.push_back(dims);
            dims.insert(dims.begin(), 1);
            scan_output_target->Resize(dims);
            scan_output_target->CopyFrom(scan_output);
          } else {
            auto dims = scan_output.sizes().vec();
            CAFFE_ENFORCE_EQ(
                dims,
                scan_outputs_sizes[i],
                "Size of scan output changed across iterations");
            dims.insert(dims.begin(), itr);
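            // Grow the leading (iteration) dimension by one; growthPct=100
            // over-allocates capacity to amortize repeated reallocation. The
            // new slice is then filled by the raw copy below.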
            scan_output_target->Extend(1, 100);

            int64_t timestep_size = 1;
            for (const int64_t t : scan_outputs_sizes[i]) {
              timestep_size *= t;
            }

            const void* src_data = scan_output.raw_data();
            auto& sot_meta = scan_output_target->dtype();
            void* dst_data =
                (char*)scan_output_target->raw_mutable_data(sot_meta) +
                timestep_size * scan_output.itemsize() * itr;
            memcpy(dst_data, src_data, timestep_size * scan_output.itemsize());
          }
        }
        scope_->set_iteration(itr + 1ll);
        scope_->template set_input_condition<CondVarType>(cur_output_condition);
      } else {
        break;
      }
    }

    // Copy out final loop-carried dependencies
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      Output(i)->CopyFrom(*scope_->lcd_tensor(i));
    }

    return true;
  }

 private:
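  // LocalScope wraps a single loop workspace: it instantiates the body net
  // there and caches pointers to the blobs the loop protocol uses (iteration
  // counter, input/output condition variables, loop-carried dependencies).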
  class LocalScope {
   public:
    LocalScope(
        Workspace* loop_ws,
        const NetDef& body_net_def,
        size_t num_lcds)
        : loop_ws_(loop_ws) {
      CAFFE_ENFORCE(loop_ws_, "Failed to initialize local loop workspace");

      // Create loop-carried deps in Workspace
      lcd_tensors_.clear();
      for (int i = 2; i < num_lcds + 2; ++i) {
        Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i));
        Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType());
        lcd_tensors_.push_back(t);
      }
      // First input is the iteration variable
      auto* iteration_var_blob = loop_ws_->CreateBlob(
          body_net_def.external_input(0));
      iteration_var_ =
          BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType());

      input_condition_var_ = BlobGetMutableTensor(
          loop_ws_->CreateBlob(body_net_def.external_input(1)),
          Context::GetDeviceType());

      auto* condition_var_blob =
          loop_ws_->CreateBlob(body_net_def.external_output(0));
      condition_var_ =
          BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType());
      condition_var_->Resize(1);
      condition_var_->template mutable_data<bool>();

      body_net_ = loop_ws_->GetNet(body_net_def.name());
      if (!body_net_) {
        body_net_ = loop_ws_->CreateNet(body_net_def, true);
      }
      CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet");
    }

    NetBase* net() const {
      return body_net_;
    }

    Workspace* workspace() const {
      return loop_ws_;
    }

    int64_t iteration() const {
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      return *iteration_var_ptr;
    }

    Tensor* lcd_tensor(int idx) {
      return lcd_tensors_[idx];
    }

    void set_iteration(int64_t itr) {
      iteration_var_->Resize();
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      *iteration_var_ptr = itr;
    }

    template <typename CondVarType>
    void set_input_condition(bool cond_value) {
      input_condition_var_->Resize(1);
      auto* input_condition_var_ptr =
          input_condition_var_->template mutable_data<CondVarType>();
      *input_condition_var_ptr = cond_value;
    }

    template <typename CondVarType>
    bool output_condition() const {
      auto* condition_var_ptr =
          condition_var_->template mutable_data<CondVarType>();
      return *condition_var_ptr;
    }

   private:
    Workspace* loop_ws_;

    NetBase* body_net_; // owned by a workspace
    Tensor* iteration_var_;
    Tensor* input_condition_var_;
    Tensor* condition_var_;

    std::vector<Tensor*> lcd_tensors_;
  };

  NetDef body_net_def_;
  Workspace* parent_ws_;
  detail::WorkspaceStack ws_stack_;

  bool has_trip_count_;
  bool has_cond_;
  bool save_scopes_;
  bool disable_scopes_;
  int64_t num_loop_carried_deps_;

  std::shared_ptr<LocalScope> scope_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_