|  | #ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ | 
|  | #define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ | 
|  |  | 
|  | #include "caffe2/core/context.h" | 
|  | #include "caffe2/core/logging.h" | 
|  | #include "caffe2/core/operator.h" | 
|  | #include "caffe2/operators/create_scope_op.h" | 
|  | #include "c10/util/irange.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <class Context> | 
|  | class ONNXWhileOp final : public Operator<Context> { | 
|  | public: | 
|  | explicit ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws), | 
|  | parent_ws_(ws), | 
|  | has_trip_count_( | 
|  | this->template GetSingleArgument<int64_t>("has_trip_count", 0)), | 
|  | has_cond_(this->template GetSingleArgument<int64_t>("has_cond", 0)), | 
|  | save_scopes_( | 
|  | this->template GetSingleArgument<int64_t>("save_scopes", 0)), | 
|  | disable_scopes_( | 
|  | this->template GetSingleArgument<int64_t>("disable_scopes", 0)), | 
|  | num_loop_carried_deps_(this->template GetSingleArgument<int64_t>( | 
|  | "num_loop_carried_deps", | 
|  | -1)) { | 
|  | CAFFE_ENFORCE( | 
|  | this->template HasSingleArgumentOfType<NetDef>("body"), | 
|  | "body net must be specified in ONNXWhile operator"); | 
|  | if (disable_scopes_) { | 
|  | CAFFE_ENFORCE( | 
|  | !save_scopes_, "Cannot save scopes when disable_scopes=True"); | 
|  | } | 
|  | body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef()); | 
|  | static int64_t counter = -1; | 
|  | if (!body_net_def_.has_name()) { | 
|  | if (counter == -1) { | 
|  | ++counter; | 
|  | body_net_def_.set_name("loop_net"); | 
|  | } else { | 
|  | ++counter; | 
|  | body_net_def_.set_name("loop_net." + c10::to_string(counter)); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | bool RunOnDevice() { | 
|  | return DispatchHelper<TensorTypes<int, bool, long>>::call(this, Input(1)); | 
|  | } | 
|  |  | 
|  | // Operator | 
|  | //  Inputs: max trip count, condition, initial loop-carried dependencies | 
|  | //  Outputs: Final loop-carried dependencies, scan_outputs | 
|  | // Body | 
|  | //  Inputs: iteration number, condition, loop-carried dependencies | 
|  | //  Outputs: condition, loop-carried dependencies, scan_outputs | 
|  | template <typename CondVarType> | 
|  | bool DoRunWithType() { | 
|  | // Clear workspaces from the previous invocations of the loop | 
|  | // and setup a local scope for the first iteration | 
|  | ws_stack_.clear(); | 
|  | auto loop_ws = !disable_scopes_ | 
|  | ? ws_stack_.pushForwardWorkspace(parent_ws_).get() | 
|  | : parent_ws_; | 
|  |  | 
|  | constexpr int64_t num_inputs_before_lcds = 2; | 
|  | // First input is the maximumt trip count. Second input is the condition | 
|  | // variable (for the first iteration). The rest of the inputs are | 
|  | // loop-carried dependencies. | 
|  | int64_t num_loop_carried_deps; | 
|  | if (num_loop_carried_deps_ != -1) { | 
|  | num_loop_carried_deps = num_loop_carried_deps_; | 
|  | } else { | 
|  | num_loop_carried_deps = InputSize() - num_inputs_before_lcds; | 
|  | } | 
|  | int64_t max_trip_count = *Input(0).template data<int64_t>(); | 
|  | const bool first_iter_condition = *Input(1).template data<CondVarType>(); | 
|  |  | 
|  | scope_ = std::make_shared<LocalScope>( | 
|  | loop_ws, body_net_def_, num_loop_carried_deps); | 
|  |  | 
|  | // Body graph has 1+N+K outputs: recalculated condition variable, N | 
|  | // loop-carried dependencies, and K scan_outputs | 
|  | int num_scan_outputs = | 
|  | scope_->net()->external_output().size() - num_loop_carried_deps - 1; | 
|  |  | 
|  | CAFFE_ENFORCE_GE( | 
|  | num_scan_outputs, | 
|  | 0, | 
|  | "Body graph must have N+K outputs, where N is the number " | 
|  | "of loop-carried dependencies and K is the number of scan " | 
|  | "outputs"); | 
|  |  | 
|  | // Copy initial loop-carried dependencies | 
|  | for (const auto i : c10::irange(num_loop_carried_deps)) { | 
|  | scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds)); | 
|  | } | 
|  |  | 
|  | // Initialize iteration variable | 
|  | scope_->set_iteration(0ll); | 
|  |  | 
|  | // Initialize input condition variable | 
|  | scope_->template set_input_condition<CondVarType>(first_iter_condition); | 
|  |  | 
|  | auto valid_iter_num = [this, max_trip_count](int64_t i) { | 
|  | if (has_trip_count_) { | 
|  | return i < max_trip_count; | 
|  | } else { | 
|  | return true; | 
|  | } | 
|  | }; | 
|  |  | 
|  | auto condition_true = [this, first_iter_condition]( | 
|  | int64_t i, bool cond_value) { | 
|  | if (has_cond_) { | 
|  | if (i == 0) { | 
|  | return (bool)first_iter_condition; | 
|  | } else { | 
|  | return cond_value; | 
|  | } | 
|  | } else { | 
|  | return true; | 
|  | } | 
|  | }; | 
|  |  | 
|  | // Allocate scan_outputs for zero-iteration case | 
|  | for (const auto i : c10::irange(num_scan_outputs)) { | 
|  | Output(i + num_loop_carried_deps)->Resize(0); | 
|  | Output(i + num_loop_carried_deps)->template mutable_data<int32_t>(); | 
|  | } | 
|  |  | 
|  | // Use this to keep track of the sizes of the scan outputs and validate | 
|  | // they're the same across iterations. | 
|  | std::vector<std::vector<int64_t>> scan_outputs_sizes; | 
|  |  | 
|  | Workspace* cur_ws = nullptr; | 
|  | bool cur_output_condition = false; | 
|  |  | 
|  | while (true) { | 
|  | int64_t itr = scope_->iteration(); | 
|  | if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) { | 
|  | if (!scope_->net()->Run()) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | cur_ws = scope_->workspace(); | 
|  | cur_output_condition = scope_->template output_condition<CondVarType>(); | 
|  | if (save_scopes_) { | 
|  | loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get(); | 
|  | scope_ = std::make_shared<LocalScope>( | 
|  | loop_ws, body_net_def_, num_loop_carried_deps); | 
|  | } | 
|  |  | 
|  | // Copy forward loop-carried dependencies | 
|  | for (const auto i : c10::irange(num_loop_carried_deps)) { | 
|  | Blob* b = cur_ws->GetBlob(scope_->net()->external_output()[i + 1]); | 
|  | const Tensor& t = b->template Get<Tensor>(); | 
|  | scope_->lcd_tensor(i)->CopyFrom(t); | 
|  | } | 
|  | // Copy out scan_outputs | 
|  | for (const auto i : c10::irange(num_scan_outputs)) { | 
|  | int net_output_idx = i + 1 + num_loop_carried_deps; | 
|  | const Tensor& scan_output = | 
|  | cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx]) | 
|  | ->template Get<Tensor>(); | 
|  | auto* scan_output_target = Output(i + num_loop_carried_deps); | 
|  | if (itr == 0) { | 
|  | auto dims = scan_output.sizes().vec(); | 
|  | scan_outputs_sizes.push_back(dims); | 
|  | dims.insert(dims.begin(), 1); | 
|  | scan_output_target->Resize(dims); | 
|  | scan_output_target->CopyFrom(scan_output); | 
|  | } else { | 
|  | auto dims = scan_output.sizes().vec(); | 
|  | CAFFE_ENFORCE_EQ( | 
|  | dims, | 
|  | scan_outputs_sizes[i], | 
|  | "Size of scan output changed across iterations"); | 
|  | dims.insert(dims.begin(), itr); | 
|  | scan_output_target->Extend(1, 100); | 
|  |  | 
|  | int64_t timestep_size = 1; | 
|  | for (const int64_t t : scan_outputs_sizes[i]) { | 
|  | timestep_size *= t; | 
|  | } | 
|  |  | 
|  | const void* src_data = scan_output.raw_data(); | 
|  | auto& sot_meta = scan_output_target->dtype(); | 
|  | void* dst_data = | 
|  | (char*)scan_output_target->raw_mutable_data(sot_meta) + | 
|  | timestep_size * scan_output.itemsize() * itr; | 
|  | memcpy(dst_data, src_data, timestep_size * scan_output.itemsize()); | 
|  | } | 
|  | } | 
|  | scope_->set_iteration(itr + 1ll); | 
|  | scope_->template set_input_condition<CondVarType>(cur_output_condition); | 
|  | } else { | 
|  | break; | 
|  | } | 
|  | } | 
|  |  | 
|  | // Copy out final loop-carried dependencies | 
|  | for (const auto i : c10::irange(num_loop_carried_deps)) { | 
|  | Output(i)->CopyFrom(*scope_->lcd_tensor(i)); | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | class LocalScope { | 
|  | public: | 
|  | LocalScope(Workspace* loop_ws, const NetDef& body_net_def, size_t num_lcds) | 
|  | : loop_ws_(loop_ws) { | 
|  | CAFFE_ENFORCE(loop_ws_, "Failed to initialize local loop workspace"); | 
|  |  | 
|  | // Create loop-carried deps in Workspace | 
|  | lcd_tensors_.clear(); | 
|  | // NOLINTNEXTLINE(clang-diagnostic-sign-compare) | 
|  | for (int i = 2; i < num_lcds + 2; ++i) { | 
|  | Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i)); | 
|  | Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType()); | 
|  | lcd_tensors_.push_back(t); | 
|  | } | 
|  | // First output is the iteration variable | 
|  | auto* iteration_var_blob = | 
|  | loop_ws_->CreateBlob(body_net_def.external_input(0)); | 
|  | iteration_var_ = | 
|  | BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType()); | 
|  |  | 
|  | input_condition_var_ = BlobGetMutableTensor( | 
|  | loop_ws_->CreateBlob(body_net_def.external_input(1)), | 
|  | Context::GetDeviceType()); | 
|  |  | 
|  | auto* condition_var_blob = | 
|  | loop_ws_->CreateBlob(body_net_def.external_output(0)); | 
|  | condition_var_ = | 
|  | BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType()); | 
|  | condition_var_->Resize(1); | 
|  | condition_var_->template mutable_data<bool>(); | 
|  |  | 
|  | body_net_ = loop_ws_->GetNet(body_net_def.name()); | 
|  | if (!body_net_) { | 
|  | body_net_ = loop_ws_->CreateNet(body_net_def, true); | 
|  | } | 
|  | CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet"); | 
|  | } | 
|  |  | 
|  | NetBase* net() const { | 
|  | return body_net_; | 
|  | } | 
|  |  | 
|  | Workspace* workspace() const { | 
|  | return loop_ws_; | 
|  | } | 
|  |  | 
|  | int64_t iteration() const { | 
|  | auto* iteration_var_ptr = | 
|  | iteration_var_->template mutable_data<int64_t>(); | 
|  | return *iteration_var_ptr; | 
|  | } | 
|  |  | 
|  | Tensor* lcd_tensor(int idx) { | 
|  | return lcd_tensors_[idx]; | 
|  | } | 
|  |  | 
|  | void set_iteration(int64_t itr) { | 
|  | iteration_var_->Resize(); | 
|  | auto* iteration_var_ptr = | 
|  | iteration_var_->template mutable_data<int64_t>(); | 
|  | *iteration_var_ptr = itr; | 
|  | } | 
|  |  | 
|  | template <typename CondVarType> | 
|  | void set_input_condition(bool cond_value) { | 
|  | input_condition_var_->Resize(1); | 
|  | auto* input_condition_var_ptr = | 
|  | input_condition_var_->template mutable_data<CondVarType>(); | 
|  | *input_condition_var_ptr = cond_value; | 
|  | } | 
|  |  | 
|  | template <typename CondVarType> | 
|  | bool output_condition() const { | 
|  | auto* condition_var_ptr = | 
|  | condition_var_->template mutable_data<CondVarType>(); | 
|  | return *condition_var_ptr; | 
|  | } | 
|  |  | 
|  | private: | 
|  | Workspace* loop_ws_; | 
|  |  | 
|  | NetBase* body_net_; // owned by a workspace | 
|  | Tensor* iteration_var_; | 
|  | Tensor* input_condition_var_; | 
|  | Tensor* condition_var_; | 
|  |  | 
|  | std::vector<Tensor*> lcd_tensors_; | 
|  | }; | 
|  |  | 
|  | NetDef body_net_def_; | 
|  | Workspace* parent_ws_; | 
|  | detail::WorkspaceStack ws_stack_; | 
|  |  | 
|  | bool has_trip_count_; | 
|  | bool has_cond_; | 
|  | bool save_scopes_; | 
|  | bool disable_scopes_; | 
|  | int64_t num_loop_carried_deps_; | 
|  |  | 
|  | std::shared_ptr<LocalScope> scope_; | 
|  | }; | 
|  |  | 
|  | } // namespace caffe2 | 
|  |  | 
#endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_