| graph(%a.1_data : Tensor, | |
| %a.1_mask : Tensor, | |
| %a.1_dims : Tensor, | |
| %b_data : Tensor, | |
| %b_mask : Tensor, | |
| %b_dims : Tensor): | |
| %6 : int = prim::Constant[value=1]() | |
| %7 : Tensor = aten::gt(%a.1_data, %b_data) | |
| %8 : Tensor = aten::mul(%a.1_mask, %b_mask) | |
| %9 : Long() = prim::NumToTensor(%6) | |
| %alpha.1 : float = prim::Float(%9) | |
| %data.1 : Tensor = aten::add(%a.1_data, %b_data, %alpha.1) | |
| %mask.1 : Tensor = aten::mul(%a.1_mask, %b_mask) | |
| %dims.1 : Tensor = aten::__or__(%a.1_dims, %b_dims) | |
| %14 : Long() = prim::NumToTensor(%6) | |
| %alpha : float = prim::Float(%14) | |
| %data : Tensor = aten::sub(%a.1_data, %b_data, %alpha) | |
| %mask : Tensor = aten::mul(%a.1_mask, %b_mask) | |
| %dims : Tensor = aten::__or__(%a.1_dims, %b_dims) | |
| %19 : bool = prim::Constant[value=1]() | |
| %20 : int = prim::Constant[value=1]() | |
| %21 : Tensor = aten::type_as(%8, %7) | |
| %data.2 : Tensor = aten::mul(%7, %21) | |
| %23 : int = aten::dim(%data.2) | |
| %24 : bool = aten::eq(%23, %20) | |
| %cond_data : Tensor, %cond_mask : Tensor = prim::If(%24) | |
| block0(): | |
| %27 : int = aten::dim(%data.1) | |
| %28 : int = aten::sub(%27, %20) | |
| %data.4 : Tensor = prim::Loop(%28, %19, %data.2) | |
| block0(%30 : int, %31 : Tensor): | |
| %32 : int = aten::dim(%31) | |
| %data.3 : Tensor = aten::unsqueeze(%31, %32) | |
| -> (%19, %data.3) | |
| %cond_data.1 : Tensor = aten::expand_as(%data.4, %data.1) | |
| %cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask.1) | |
| -> (%cond_data.1, %cond_mask.1) | |
| block1(): | |
| -> (%data.2, %data.2) | |
| %res_data : Tensor = aten::where(%cond_data, %data.1, %data) | |
| %res_mask : Tensor = aten::where(%cond_mask, %mask.1, %mask) | |
| %res_dims : Tensor = aten::__or__(%dims.1, %dims) | |
| %39 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims) | |
| return (%39) |