| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" |
| |
| #include <cstdlib> |
| #include <numeric> |
| #include <vector> |
| |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" |
| #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" |
| #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/compiler/xla/window_util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/platform/logging.h" |
| |
| namespace xla { |
| namespace gpu { |
| |
| namespace { |
| |
| HloInstruction* CreateGpuConv(const char* call_target, const Shape& shape, |
| HloInstruction* lhs, HloInstruction* rhs, |
| const Window& window, |
| const ConvolutionDimensionNumbers& dnums, |
| int64 feature_group_count, |
| const OpMetadata& metadata) { |
| HloComputation* computation = lhs->parent(); |
| |
| // This call returns a tuple of (conv_result, scratch_memory), where |
| // conv_result is the actual result of the convolution, and scratch_memory is |
| // temporary memory used by cudnn. |
| // |
| // At the moment, we don't know how much scratch memory this conv is going to |
| // use, so we put u8[0] in this place. Later on another pass will choose |
| // which conv algorithm to use, and at that point we'll modify the shape of |
| // this second tuple element. |
| Shape call_shape = |
| ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); |
| |
| HloInstruction* custom_call = computation->AddInstruction( |
| HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); |
| custom_call->set_window(window); |
| custom_call->set_convolution_dimension_numbers(dnums); |
| custom_call->set_feature_group_count(feature_group_count); |
| custom_call->set_metadata(metadata); |
| return custom_call; |
| } |
| |
| bool CanImplementAsGpuForwardConv(HloInstruction* conv) { |
| const ConvolutionDimensionNumbers& dnums = |
| conv->convolution_dimension_numbers(); |
| if (dnums.input_spatial_dimensions_size() > 3) { |
| return false; |
| } |
| |
| // CuDNN does not accept zero-element arguments |
| if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) || |
| ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) { |
| return false; |
| } |
| |
| // CuDNN can perform either cross correlation (no reversal), |
| // or convolution (all dimensions reversed). |
| if (dnums.input_spatial_dimensions_size() == 2 |
| ? !window_util::AllOrNoneReversed(conv->window()) |
| : window_util::HasWindowReversal(conv->window())) { |
| return false; |
| } |
| return true; |
| } |
| |
| // Try to match a backward filter pattern that contains "conv". |
| // Precondition: "conv" is a kConvolution. |
| std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*> |
| MatchBackwardFilter(HloInstruction* conv) { |
| const auto no_match_result = |
| std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); |
| |
| // Step 1: match the instruction pattern without considering the paddings and |
| // dimension numbers just yet. We may need some generic pattern matcher |
| // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h |
| // |
| // Backward filter convolution is implemented in XLA as the forward |
| // convolution of padded activations and dilated gradients. Padding on |
| // activations and dilation on gradients are specified in the "window" field |
| // of the forward convolution. |
| // |
| // activations gradients |
| // \ / |
| // v v |
| // Convolution |
| // conv |
| CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); |
| |
| // Step 2: match paddings and dimension numbers of the forward convolution. |
| const ConvolutionDimensionNumbers& conv_dnums = |
| conv->convolution_dimension_numbers(); |
| auto input_batch_dim = conv_dnums.input_batch_dimension(); |
| auto input_feature_dim = conv_dnums.input_feature_dimension(); |
| auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); |
| auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); |
| auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); |
| auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); |
| auto output_batch_dim = conv_dnums.output_batch_dimension(); |
| auto output_feature_dim = conv_dnums.output_feature_dimension(); |
| auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); |
| |
| for (const WindowDimension& window_dim : conv->window().dimensions()) { |
| if (window_dim.stride() != 1) { |
| VLOG(1) << "Forward convolution's window " |
| << conv->window().ShortDebugString() |
| << " should have stride of 1."; |
| return no_match_result; |
| } |
| if (window_dim.base_dilation() != 1) { |
| VLOG(1) << "Forward convolution's window " |
| << conv->window().ShortDebugString() |
| << " should have no base (LHS) dilation."; |
| return no_match_result; |
| } |
| if (window_dim.padding_low() < 0) { |
| VLOG(1) << "Padding low should be non-negative."; |
| return no_match_result; |
| } |
| if (window_dim.window_reversal()) { |
| VLOG(1) << "Window reversal field not supported"; |
| return no_match_result; |
| } |
| // Padding high will be checked in Step 3. |
| } |
| if (input_batch_dim == output_batch_dim && |
| !window_util::HasWindowDilation(conv->window())) { |
| VLOG(1) << conv->ToString() |
| << " is a regular forward convolution. No need " |
| "to fold it to a backward filter convolution."; |
| return no_match_result; |
| } |
| auto rhs_in = |
| conv->mutable_operand(1)->shape().dimensions(kernel_input_feature_dim); |
| if (conv->feature_group_count() > 1 && rhs_in == 1 && |
| input_batch_dim == output_batch_dim) { |
| VLOG(1) << conv->ToString() |
| << " is a depthwise forward convolution. No need to fold to " |
| "backward filter."; |
| return no_match_result; |
| } |
| |
| // Step 3: fuse the matched HLOs into a backward convolution instruction. |
| // |
| // Compute the window of the backward convolution. |
| Window backward_conv_window; |
| for (int i = 0; i < input_spatial_dims.size(); ++i) { |
| WindowDimension* dim = backward_conv_window.add_dimensions(); |
| // The window size of the backward convolution equals the output size of the |
| // forward convolution. |
| int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]); |
| dim->set_size(filter_size); |
| // The window stride equals the window dilation of the forward convolution. |
| dim->set_stride(conv->window().dimensions(i).window_dilation()); |
| // The window's low padding is the same as the low padding of the |
| // activations. |
| dim->set_padding_low(conv->window().dimensions(i).padding_low()); |
| dim->set_base_dilation(1); |
| dim->set_window_dilation(1); |
| |
| int64 input_size = |
| conv->operand(0)->shape().dimensions(input_spatial_dims[i]); |
| int64 output_size = conv->window().dimensions(i).size(); |
| // Compute the range of the amount of valid high padding. We first compute |
| // min_padding_high, the amount of padding on the right/bottom to ensure the |
| // last patch ends at the border, i.e., |
| // |
| // input_size + dim->padding_low() + min_padding_high |
| // = (output_size - 1) * stride + filter_size |
| // |
| // Because convolution ignores trailing incomplete windows, any amount of |
| // padding high from min_padding_high to min_padding_high+stride-1 |
| // (max_padding_high) has the same effect. |
| int64 padded_input_size = filter_size + (output_size - 1) * dim->stride(); |
| int64 min_padding_high = |
| padded_input_size - input_size - dim->padding_low(); |
| int64 max_padding_high = min_padding_high + dim->stride() - 1; |
| CHECK_GE(dim->padding_low(), 0); |
| // In practice, since cuDNN convolution only supports even padding, we make |
| // the amount of high padding the same as the amount of low padding as long |
| // as it is between min_padding_high and max_padding_high. If it is not in |
| // that range, we pick the one that's closest to dim->padding_low() and let |
| // GpuConvPaddingLegalization canonicalize the resultant backward |
| // convolution later. Picking the closest one minimizes the cost of the kPad |
| // instruction to be inserted by GpuConvPaddingLegalization. |
| if (dim->padding_low() >= min_padding_high && |
| dim->padding_low() <= max_padding_high) { |
| dim->set_padding_high(dim->padding_low()); |
| } else { |
| if (dim->padding_low() < min_padding_high) { |
| dim->set_padding_high(min_padding_high); |
| } else { |
| dim->set_padding_high(max_padding_high); |
| } |
| } |
| if (dim->padding_high() < 0) { |
| LOG(WARNING) |
| << "Fusing this pattern to backward filter convolution would cause " |
| "negative padding (" |
| << dim->padding_high() |
| << ") on right/bottom of the weight gradients, which is not " |
| "supported by GpuConvPaddingLegalization (b/32744257). " |
| "Falling back to " |
| "unfused convolution for instruction: " |
| << conv->ToString(); |
| return no_match_result; |
| } |
| } |
| |
| // Restore the dimension numbers of the backward convolution from the forward |
| // convolution. The two activation dimensions are reversed (batch and |
| // feature). |
| ConvolutionDimensionNumbers backward_conv_dnums; |
| backward_conv_dnums.set_input_batch_dimension(input_feature_dim); |
| backward_conv_dnums.set_input_feature_dimension(input_batch_dim); |
| for (int i = 0; i < input_spatial_dims.size(); ++i) { |
| backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); |
| } |
| backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); |
| backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); |
| for (int i = 0; i < kernel_spatial_dims.size(); ++i) { |
| backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); |
| } |
| // The dimension numbering of the output of the forward convolution (before |
| // transposition) is the same as that of the activations (according to the |
| // semantics of kConvolution). The batch dimension of the activations should |
| // be treated as the input feature dimension, and the feature dimension should |
| // be treated as the output feature. |
| backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); |
| backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); |
| for (int i = 0; i < output_spatial_dims.size(); ++i) { |
| backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); |
| } |
| |
| HloInstruction* lhs = conv->mutable_operand(0); |
| if (conv->feature_group_count() == 1) { |
| return std::make_tuple(true, backward_conv_window, backward_conv_dnums, |
| lhs); |
| } |
| |
| int64 input_batch_dimension = backward_conv_dnums.input_batch_dimension(); |
| int64 input_feature_dimension = backward_conv_dnums.input_feature_dimension(); |
| |
| int64 input_batch = lhs->shape().dimensions(input_batch_dimension); |
| int64 input_feature = lhs->shape().dimensions(input_feature_dimension); |
| |
| // Reshape batch_dim G*N -> [G,N] |
| std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions()); |
| auto num_groups = conv->feature_group_count(); |
| CHECK_EQ(input_batch % num_groups, 0) |
| << "Input batch should be an exact multiple of feature group count"; |
| reshape_dims[input_batch_dimension] = |
| reshape_dims[input_batch_dimension] / num_groups; |
| reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups); |
| |
| HloComputation* c = conv->parent(); |
| HloInstruction* lhs_reshape_1 = |
| c->AddInstruction(HloInstruction::CreateReshape( |
| ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), |
| lhs)); |
| |
| // Transpose G to the axis before C/G, For eg: [G, N, C/G, H, W] -> [N, G, |
| // C/G, H, W] |
| std::vector<int64> transpose_dims(lhs_reshape_1->shape().dimensions_size()); |
| std::iota(transpose_dims.begin(), transpose_dims.end(), 0); |
| transpose_dims.erase(transpose_dims.begin() + input_batch_dimension); |
| transpose_dims.insert(transpose_dims.begin() + input_feature_dimension, |
| input_batch_dimension); |
| std::vector<int64> transpose_reshape_dims = |
| SpanToVector(lhs_reshape_1->shape().dimensions()); |
| transpose_reshape_dims.erase(transpose_reshape_dims.begin() + |
| input_batch_dimension); |
| transpose_reshape_dims.insert( |
| transpose_reshape_dims.begin() + input_feature_dimension, num_groups); |
| |
| HloInstruction* lhs_transpose = |
| c->AddInstruction(HloInstruction::CreateTranspose( |
| ShapeUtil::MakeShape(lhs_reshape_1->shape().element_type(), |
| transpose_reshape_dims), |
| lhs_reshape_1, transpose_dims)); |
| |
| // Merge [G,C/G] -> [C] |
| Shape new_shape = lhs_transpose->shape(); |
| new_shape.DeleteDimension(input_feature_dimension); |
| new_shape.set_dimensions(input_feature_dimension, |
| input_feature * conv->feature_group_count()); |
| HloInstruction* lhs_reshape_2 = c->AddInstruction( |
| HloInstruction::CreateReshape(new_shape, lhs_transpose)); |
| return std::make_tuple(true, backward_conv_window, backward_conv_dnums, |
| lhs_reshape_2); |
| } |
| |
| // Try to match a backward input pattern that contains "conv". |
| // Precondition: "conv" is a kConvolution. |
| std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*> |
| MatchBackwardInput(HloInstruction* conv) { |
| const auto no_match_result = |
| std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); |
| |
| // TODO: Theoretically cuDNN supports grouped convolutions also |
| // for the backward input convolution, but based on the cudnn's current state |
| // there is not much performance improvement when using the |
| // cudnn backward input API for grouped conv. |
| // This needs to be re-evaluated for future cuDNN versions. |
| // Note that we already have the necessary code down below, the only thing to |
| // enable it is to remove the following early return. |
| if (conv->feature_group_count() > 1) { |
| return no_match_result; |
| } |
| |
| // Match instruction pattern. |
| CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); |
| HloInstruction* reverse_filter = conv->mutable_operand(1); |
| ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); |
| |
| // Match BackwardInput for a depthwise convolution and thunk it to forward |
| // convolution Output feature dimension and input feature dimension has been |
| // swapped in the bridge. Hence to get the actual input features we need to |
| // query the output feature dimension |
| auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension(); |
| auto kernel_out_features = |
| reverse_filter->shape().dimensions(kernel_out_feature_dim); |
| |
| // For a depthwise convolution, the input features must be equal to the |
| // feature_group_count. We can leverage this property to match a depthwise |
| // convolution and thunk it to forward conv |
| if (conv->feature_group_count() > 1 && |
| kernel_out_features == conv->feature_group_count()) { |
| return no_match_result; |
| } |
| |
| // We pattern-match to a backwards input conv if: |
| // |
| // - all spatial dims of the filter are reversed |
| // |
| // OR |
| // |
| // - filter is 1x1 or a constant AND |
| // - conv has base dilation (otherwise this is just a regular forward conv). |
| // |
| // The final criterion above is just for canonicalization; cudnn seems to run |
| // just as fast if we canonicalize 1x1/constant filters without base dilation |
| // to forward or backward convs. We canonicalize to forward conv because (a) |
| // it's more natural (constant filters usually show up when doing inference, |
| // and having backwards convolutions in inference graphs would be weird), and |
| // (b) cudnn has special fusions for forward conv plus bias and activation, |
| // and we want to pattern-match to that after running this pass. |
| bool is_reversed_filter = |
| reverse_filter->opcode() == HloOpcode::kReverse && |
| absl::c_is_permutation(dnums.kernel_spatial_dimensions(), |
| reverse_filter->dimensions()); |
| bool is_1x1_filter = |
| absl::c_all_of(conv->window().dimensions(), |
| [](const WindowDimension& d) { return d.size() == 1; }); |
| if (!is_reversed_filter && |
| !(window_util::HasBaseDilation(conv->window()) && |
| (reverse_filter->IsConstant() || is_1x1_filter))) { |
| VLOG(1) << "Can't match to backwards convolution. Either filter is not " |
| "kReverse, or it's not a base-dilated conv with a 1x1 or " |
| "constant filter."; |
| return no_match_result; |
| } |
| |
| // Match padding and dilation of the forward convolution. |
| for (const WindowDimension& window_dim : conv->window().dimensions()) { |
| if (window_dim.stride() != 1) { |
| VLOG(1) << "Forward convolution's window " |
| << conv->window().ShortDebugString() |
| << " should have stride of 1."; |
| return no_match_result; |
| } |
| if (window_dim.window_dilation() != 1) { |
| VLOG(1) << "Forward convolution's window " |
| << conv->window().ShortDebugString() |
| << " should have no window dilation."; |
| return no_match_result; |
| } |
| if (window_dim.window_reversal()) { |
| VLOG(1) << "Window reversal field not supported"; |
| return no_match_result; |
| } |
| } |
| |
| const auto& input_spatial_dims = dnums.input_spatial_dimensions(); |
| const auto& output_spatial_dims = dnums.output_spatial_dimensions(); |
| CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size()); |
| CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size()); |
| |
| const Window& old_window = conv->window(); |
| Window new_window = old_window; |
| for (size_t i = 0; i < input_spatial_dims.size(); ++i) { |
| // Restore backward convolution's padding config from the matched pattern. |
| // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we |
| // convert backward input convolution to a variant of forward convolution. |
| // |
| // The stride of the backward convolution |
| // = the base dilation factor of the forward convolution |
| auto dim = new_window.mutable_dimensions(i); |
| dim->set_stride(old_window.dimensions(i).base_dilation()); |
| dim->set_base_dilation(1); |
| |
| // The low padding = kernel_size - 1 - low padding on the gradients |
| // Make sure the low padding is not negative. |
| auto kernel_size = old_window.dimensions(i).size(); |
| auto backward_padding_low = |
| kernel_size - 1 - old_window.dimensions(i).padding_low(); |
| if (backward_padding_low < 0) { |
| LOG(WARNING) |
| << "The low padding of the backward convolution would be negative (" |
| << backward_padding_low |
| << "), which isn't supported by GpuConvPaddingLegalization " |
| "for now (b/32744257)."; |
| return no_match_result; |
| } |
| dim->set_padding_low(backward_padding_low); |
| |
| // Compute the range of the amount of padding on the right/bottom of the |
| // activations. XLA's convolution requires all patches to be within the |
| // padded base. This gives us flexiblity to choose the amount of high |
| // padding from a set of values without changing the result of the backward |
| // convolution. The minimum amount (min_padding_high) makes the last patch |
| // end at the border. The maximum amount (max_padding_high) equals |
| // min_padding_high+stride-1 -- max_padding_high+1 would cause the output |
| // size to change. |
| auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]); |
| auto output_size = |
| conv->operand(0)->shape().dimensions(input_spatial_dims[i]); |
| auto padded_input_size = kernel_size + dim->stride() * (output_size - 1); |
| auto total_pad_size = padded_input_size - unpadded_input_size; |
| auto min_padding_high = total_pad_size - backward_padding_low; |
| auto max_padding_high = min_padding_high + dim->stride() - 1; |
| |
| if (backward_padding_low >= min_padding_high && |
| backward_padding_low <= max_padding_high) { |
| // In the best case (most likely), if backward_padding_low is in the range |
| // of the amounts of valid high padding, we choose backward_padding_low |
| // because cuDNN supports even padding only. |
| dim->set_padding_high(backward_padding_low); |
| } else { |
| // Otherwise, we choose the amount that's closest to backward_padding_low, |
| // and GpuConvPaddingLegalization will later insert kSlice |
| // instructions to enforce even padding. |
| // |
| // For example, consider the backward convolution pattern |
| // |
| // ab xy |
| // | pad | reverse |
| // .a.b yx |
| // \ / |
| // ABC |
| // |
| // The amount of low padding on activations (in backward convolution) is |
| // backward_padding_low = kernel_size - 1 - forward_padding_low |
| // = 2 - 1 - 1 = 0 |
| // |
| // The amount of padding high must be between 1 and 2, in order to make |
| // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in |
| // the range of [1,2], so we pick the closest valid amount of padding |
| // high, which is 1 in this case. Therefore, we fuse the above pattern to |
| // |
| // ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1) |
| if (backward_padding_low < min_padding_high) { |
| dim->set_padding_high(min_padding_high); |
| } else { |
| dim->set_padding_high(max_padding_high); |
| } |
| } |
| // GpuConvPaddingLegalization doesn't handle backward input |
| // convolution with negative padding for now. So fall back to unfused |
| // convolution in case of negative padding. For example, |
| // ABCD = Conv(abc, reverse(xy), padding_high=2) |
| // could be fused to |
| // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1) |
| // with positive padding low but negative padding high. |
| if (dim->padding_high() < 0) { |
| LOG(WARNING) << "Fusing this pattern to backward convolution would cause " |
| "negative padding (" |
| << dim->padding_high() |
| << ") on right/bottom of the activations, which is not " |
| "supported by GpuConvPaddingLegalization (b/32744257). " |
| "Falling back to unfused convolution for instruction: " |
| << conv->ToString(); |
| return no_match_result; |
| } |
| } |
| |
| // OK, it's a match! Switch the input feature dimension with the output |
| // feature dimension. Also switch the output with the input. This is the way |
| // cuDNN expects it to be. |
| auto conv_dnums = conv->convolution_dimension_numbers(); |
| dnums.set_kernel_input_feature_dimension( |
| conv_dnums.kernel_output_feature_dimension()); |
| dnums.set_kernel_output_feature_dimension( |
| conv_dnums.kernel_input_feature_dimension()); |
| for (int i = 0; i < input_spatial_dims.size(); ++i) { |
| dnums.set_input_spatial_dimensions(i, |
| conv_dnums.output_spatial_dimensions(i)); |
| dnums.set_output_spatial_dimensions(i, |
| conv_dnums.input_spatial_dimensions(i)); |
| } |
| dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension()); |
| dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension()); |
| dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension()); |
| dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension()); |
| |
| // If we matched against a constant, we need to add a reverse op that can be |
| // subsumed by the cuDNN call. algebraic-simplifier will later remove any |
| // unnecessary reverses. |
| if (reverse_filter->opcode() != HloOpcode::kReverse && |
| reverse_filter->IsConstant()) { |
| // Create a double-reverse, which is a nop. |
| HloComputation* c = conv->parent(); |
| reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( |
| reverse_filter->shape(), reverse_filter, |
| AsInt64Slice(dnums.kernel_spatial_dimensions()))); |
| reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( |
| reverse_filter->shape(), reverse_filter, |
| AsInt64Slice(dnums.kernel_spatial_dimensions()))); |
| TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter)); |
| } |
| |
| // Calculate the 'rhs' that goes into the backward input convolution. |
| HloInstruction* rhs = reverse_filter; |
| // One reverse is subsumed by the cuDNN call. |
| if (rhs->opcode() == HloOpcode::kReverse) { |
| rhs = rhs->mutable_operand(0); |
| } |
| if (conv->feature_group_count() == 1) { |
| return std::make_tuple(true, new_window, dnums, rhs); |
| } |
| |
| // Handle grouped convolutions. Because we swapped the input feature dimension |
| // with the output feature dimension, we need to also reshape the kernel so |
| // that the 'feature_group_count' parameter still makes sense. The |
| // 'feature_group_count' parameter essentially specifies how often the |
| // 'kernel_input_feature_dimension' is repeated. So when we swap these |
| // dimensions, we need to divide the new 'kernel_input_feature_dimension' by |
| // 'feature_group_count' and multiply the new |
| // 'kernel_output_feature_dimension' by 'feature_group_count'. |
| int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); |
| int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); |
| |
| // In the backward convolution case, the spatial dimensions become the |
| // feature dimensions, and we are guaranteed that the spatial dimensions are |
| // adjacent. |
| CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); |
| int64 input_features = rhs->shape().dimensions(input_feature_dimension); |
| int64 output_features = rhs->shape().dimensions(output_feature_dimension); |
| |
| // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G, |
| // out_depth / G] |
| std::vector<int64> reshape_dims = SpanToVector(rhs->shape().dimensions()); |
| auto num_groups = conv->feature_group_count(); |
| CHECK_EQ(input_features % num_groups, 0) |
| << "Input feature count should be an exact multiple of feature group " |
| "count"; |
| reshape_dims[input_feature_dimension] = |
| reshape_dims[input_feature_dimension] / num_groups; |
| reshape_dims.insert(reshape_dims.begin() + input_feature_dimension, |
| num_groups); |
| |
| HloComputation* c = conv->parent(); |
| rhs = c->AddInstruction(HloInstruction::CreateReshape( |
| ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs)); |
| |
| // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ..., |
| // in_depth/G, G, out_depth / G] |
| std::vector<int64> transpose_dims(rhs->shape().dimensions_size()); |
| std::iota(transpose_dims.begin(), transpose_dims.end(), 0); |
| transpose_dims.erase(transpose_dims.begin() + input_feature_dimension); |
| transpose_dims.insert(transpose_dims.begin() + output_feature_dimension, |
| input_feature_dimension); |
| std::vector<int64> transpose_reshape_dims = |
| SpanToVector(rhs->shape().dimensions()); |
| transpose_reshape_dims.erase(transpose_reshape_dims.begin() + |
| input_feature_dimension); |
| transpose_reshape_dims.insert( |
| transpose_reshape_dims.begin() + output_feature_dimension, num_groups); |
| rhs = c->AddInstruction(HloInstruction::CreateTranspose( |
| ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims), |
| rhs, transpose_dims)); |
| |
| // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ..., |
| // in_depth/G, out_depth] |
| Shape new_shape = rhs->shape(); |
| new_shape.DeleteDimension(output_feature_dimension); |
| new_shape.set_dimensions(output_feature_dimension, |
| output_features * num_groups); |
| rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); |
| return std::make_tuple(true, new_window, dnums, rhs); |
| } |
| |
| CudnnConvBackendConfig GetDefaultBackendConfig() { |
| CudnnConvBackendConfig config; |
| config.set_conv_result_scale(1); |
| return config; |
| } |
| |
| // Helper function to create a custom_call instruction to replace the given |
| // conv instruction |
| static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) { |
| bool match; |
| Window window; |
| ConvolutionDimensionNumbers dnums; |
| HloInstruction* rhs; |
| HloInstruction* lhs; |
| |
| std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); |
| if (match) { |
| return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(), |
| conv->mutable_operand(0), rhs, window, dnums, |
| conv->feature_group_count(), conv->metadata()); |
| } |
| |
| std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv); |
| if (match) { |
| return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs, |
| conv->mutable_operand(1), window, dnums, |
| conv->feature_group_count(), conv->metadata()); |
| } |
| |
| // If all else fails, try a forward convolution. |
| if (CanImplementAsGpuForwardConv(conv)) { |
| if (primitive_util::IsIntegralType( |
| conv->operand(0)->shape().element_type())) { |
| // In addition to replacing a convolution instruction with |
| // a custom call, integer convolutions must have this pattern to match |
| // CuDNN semantics: |
| // conv<InputT=int32, ResultT=int32>( |
| // convert<int32>(int8_x), convert<int32>(int8_y)) |
| // We transform it to: |
| // custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward) |
| // |
| // We will error out, if the pattern is not found for integer |
| // convolution. |
| const auto is_int8_to_int32_cast = |
| [](const HloInstruction* instr) -> bool { |
| return (instr->opcode() == HloOpcode::kConvert && |
| instr->operand(0)->shape().element_type() == S8 && |
| instr->shape().element_type() == S32); |
| }; |
| HloInstruction* input_convert = conv->mutable_operand(0); |
| HloInstruction* kernel_convert = conv->mutable_operand(1); |
| if (conv->shape().element_type() != S32 || |
| !is_int8_to_int32_cast(input_convert) || |
| !is_int8_to_int32_cast(kernel_convert)) { |
| return Unimplemented( |
| "Integer convolutions for CuDNN must have this pattern: " |
| "conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), " |
| "convert<int32>(int8_y))"); |
| } |
| // Bypass the convert<int32> for both inputs. |
| TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape( |
| 0, input_convert->mutable_operand(0))); |
| TF_RETURN_IF_ERROR( |
| conv->parent()->RemoveInstructionAndUnusedOperands(input_convert)); |
| TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape( |
| 1, kernel_convert->mutable_operand(0))); |
| TF_RETURN_IF_ERROR( |
| conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert)); |
| } |
| return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(), |
| conv->mutable_operand(0), conv->mutable_operand(1), |
| conv->window(), conv->convolution_dimension_numbers(), |
| conv->feature_group_count(), conv->metadata()); |
| } |
| |
| return nullptr; |
| } |
| |
| // Tries to rewrite a single convolution into a call to cudnn/miopen. |
| StatusOr<bool> RunOnInstruction(HloInstruction* conv) { |
| CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); |
| |
| TF_ASSIGN_OR_RETURN(HloInstruction * custom_call, |
| CreateCustomCallHelper(conv)); |
| if (custom_call == nullptr) { |
| return false; |
| } |
| |
| TF_RETURN_IF_ERROR( |
| custom_call->set_backend_config(GetDefaultBackendConfig())); |
| |
| VLOG(1) << "Replacing convolution " << conv->ToString() << " with " |
| << custom_call->ToString(); |
| |
| // The CustomCall returns a tuple (conv_result, scratch_memory). Extract |
| // out the conv result and replace `conv` with it. |
| TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( |
| conv, |
| HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); |
| return true; |
| } |
| |
| // Rewrites the convolutions in the given computation into calls to |
| // cudnn/miopen. |
| // Returns true if it made any changes. |
| StatusOr<bool> RunOnComputation(HloComputation* computation) { |
| std::vector<HloInstruction*> convs; |
| for (auto* hlo : computation->instructions()) { |
| if (hlo->opcode() == HloOpcode::kConvolution) { |
| convs.push_back(hlo); |
| } |
| } |
| |
| bool changed = false; |
| for (HloInstruction* conv : convs) { |
| TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv)); |
| changed |= result; |
| } |
| return changed; |
| } |
| } // namespace |
| |
| StatusOr<bool> GpuConvRewriter::Run(HloModule* module) { |
| bool changed = false; |
| for (HloComputation* computation : module->MakeNonfusionComputations()) { |
| TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); |
| changed |= result; |
| } |
| return changed; |
| } |
| |
| } // namespace gpu |
| } // namespace xla |