/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace grappler {

class ArithmeticOptimizerTest : public GrapplerTest {
 protected:
  // Optimize a graph using optimizer and prune all the nodes that no
  // longer have any output consumers.
  void OptimizeAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
                        GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }
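
  // A minimal usage sketch for this helper. The graph construction is
  // illustrative only; any GrapplerItem with a populated graph and fetch
  // set works:
  //
  //   GrapplerItem item;
  //   // ... build item.graph (e.g. via tensorflow::Scope) and set item.fetch.
  //   ArithmeticOptimizer optimizer;
  //   GraphDef output;
  //   OptimizeAndPrune(&optimizer, &item, &output);
  //   // `output` now holds the rewritten graph with dead nodes pruned.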

  // Run CommonSubgraphElimination first, then the optimizer twice to make
  // sure the rewrite is idempotent, and finally prune dead nodes.
  void DedupAndOptimizeTwiceAndPrune(GraphOptimizer* optimizer,
                                     GrapplerItem* item, GraphDef* output) {
    TF_EXPECT_OK(CommonSubgraphElimination().Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }

  // Run optimizer twice to make sure the rewrite is idempotent.
  void OptimizeTwice(GraphOptimizer* optimizer, GrapplerItem* item,
                     GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
  }

  // Run optimizer twice to make sure the rewrite is idempotent.
  // Optionally run a constant folding pass before pruning.
  void OptimizeTwiceAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
                             GraphDef* output, bool const_folding = false) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    if (const_folding) {
      item->graph.Swap(output);
      output->Clear();
      TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
                       .Optimize(nullptr, *item, output));
    }
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }
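
  // A sketch of how the "optimize twice" helpers are typically used; the
  // optimizer type and flag value here are placeholders:
  //
  //   ArithmeticOptimizer optimizer;
  //   GraphDef output;
  //   OptimizeTwiceAndPrune(&optimizer, &item, &output,
  //                         /*const_folding=*/true);
  //   // The test then compares `output` against the expected graph, so a
  //   // rewrite that is not idempotent shows up as a mismatch after the
  //   // second Optimize() call.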

  // The helpers below control which rewrite stages are active.
  // DisableAddToAddNCombining() turns off a single stage, while each
  // EnableOnly*() helper disables every stage and then re-enables exactly
  // one, so a test can exercise a single rewrite in isolation.
  void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    optimizer->options_.combine_add_to_addn = false;
  }

  void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.combine_add_to_addn = true;
  }

  void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_conjugate_into_transpose = true;
  }

  void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_multiply_into_conv = true;
  }

  void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_transpose_into_matmul = true;
  }

  void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_common_factor_out_of_aggregation = true;
  }

  void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.minimize_broadcasts = true;
  }

  void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_identity_transpose = true;
  }

  void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_involution = true;
  }

  void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_bitcast = true;
  }

  void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_cast = true;
  }

  void EnableOnlyReorderReshapeAroundUnary(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.reorder_reshape_around_unary = true;
  }

  void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_reshape = true;
  }

  void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_negation = true;
  }

  void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.reorder_cast_like_and_value_preserving = true;
  }

  void EnableOnlyReplaceMulWithBroadcastByTile(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.replace_mul_with_tile = true;
  }

  void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.replace_mul_with_square = true;
  }

  void EnableOnlyReplacePackWithTileReshape(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.replace_pack_with_tile_reshape = true;
  }

  void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_cwise_unary_chains = true;
  }

  void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
  }

  void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_log_softmax = true;
  }

  void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_pow = true;
  }

  void EnableOnlyFuseSquaredDiff(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fuse_squared_diff = true;
  }

  void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_idempotent = true;
  }

  void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_logical_not = true;
  }

  void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.simplify_aggregation = true;
  }

  void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_log1p = true;
  }

  void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.optimize_max_or_min_of_monotonic = true;
  }

  void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_expm1 = true;
  }

  void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.unary_ops_composition = true;
  }

  void EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_stack_slice_same_axis = true;
  }

  void EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.simplify_embedding_lookup = true;
  }

  void EnableOnlyRemoveCastIntoSegmentReduction(
      ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_cast_into_segment_reduction = true;
  }

 private:
  // Replaces the optimizer's options with a configuration in which the
  // rewrite stages are switched off; the EnableOnly* helpers above then
  // re-enable exactly one stage.
  void DisableAllStages(ArithmeticOptimizer* optimizer) {
    ArithmeticOptimizer::ArithmeticOptimizerOptions options;
    options.dedup_computations = false;
    options.combine_add_to_addn = false;
    options.convert_sqrt_div_to_rsqrt_mul = false;
    options.convert_pow = false;
    options.convert_log1p = false;
    options.optimize_max_or_min_of_monotonic = false;
    options.fold_conjugate_into_transpose = false;
    options.fold_multiply_into_conv = false;
    options.fold_transpose_into_matmul = false;
    options.hoist_common_factor_out_of_aggregation = false;
    options.hoist_cwise_unary_chains = false;
    options.minimize_broadcasts = false;
    options.remove_identity_transpose = false;
    options.remove_involution = false;
    options.remove_idempotent = false;
    options.remove_redundant_bitcast = false;
    options.remove_redundant_cast = false;
    options.remove_redundant_reshape = false;
    options.remove_negation = false;
    options.remove_logical_not = false;
    options.reorder_cast_like_and_value_preserving = false;
    options.reorder_reshape_around_unary = false;
    options.replace_mul_with_tile = false;
    options.replace_mul_with_square = false;
    options.simplify_aggregation = false;
    options.unary_ops_composition = false;
    options.simplify_embedding_lookup = false;
    options.remove_cast_into_segment_reduction = false;
    optimizer->options_ = options;
  }
};
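
// A sketch of an end-to-end test built on this fixture. The graph, node
// names, and the stage being enabled are illustrative only, not part of the
// fixture itself:
//
//   TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquareSketch) {
//     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
//     Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
//     Output mul = ops::Mul(s.WithOpName("mul"), c, c);
//     Output id = ops::Identity(s.WithOpName("id"), mul);
//
//     GrapplerItem item;
//     item.fetch = {"id"};
//     TF_CHECK_OK(s.ToGraphDef(&item.graph));
//
//     ArithmeticOptimizer optimizer;
//     EnableOnlyReplaceMulWithSquare(&optimizer);
//
//     GraphDef output;
//     OptimizeAndPrune(&optimizer, &item, &output);
//     // Verify that the Mul over identical inputs was rewritten to Square,
//     // e.g. by inspecting output.node() or comparing evaluated tensors.
//   }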

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_