Add RandomHeight and RandomWidth image preproc layer.

PiperOrigin-RevId: 293494681
Change-Id: Ia60494b8fddba8c271062987d8d3ca1064a5ceeb
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
index 8148f95..0cc56ba 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
@@ -960,6 +960,182 @@
     return dict(list(base_config.items()) + list(config.items()))
 
 
+class RandomHeight(Layer):
+  """Randomly vary the height of a batch of images during training.
+
+  Adjusts the height of a batch of images by a random factor. The input
+  should be a 4-D tensor in the "channels_last" image data format.
+
+  By default, this layer is inactive during inference.
+
+  Arguments:
+    factor: A positive float (fraction of original height), or a tuple of
+      size 2 representing lower and upper bound for resizing vertically. When
+      represented as a single float, this value is used for both the upper and
+      lower bound. For instance, `factor=(0.2, 0.3)` results in an output height
+      varying in the range `[original + 20%, original + 30%]`. `factor=(-0.2,
+      0.3)` results in an output height varying in the range `[original - 20%,
+      original + 30%]`. `factor=0.2` results in an output height varying in the
+      range `[original - 20%, original + 20%]`.
+    interpolation: String, the interpolation method. Defaults to `bilinear`.
+      Supports `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
+      `gaussian`, `mitchellcubic`
+    seed: Integer. Used to create a random seed.
+
+  Input shape:
+    4D tensor with shape:
+    `(samples, height, width, channels)` (data_format='channels_last').
+
+  Output shape:
+    4D tensor with shape:
+    `(samples, random_height, width, channels)`.
+  """
+
+  def __init__(self, factor, interpolation='bilinear', seed=None, **kwargs):
+    self.factor = factor
+    if isinstance(factor, (tuple, list)):
+      self.height_lower = -factor[0]
+      self.height_upper = factor[1]
+    else:
+      self.height_lower = self.height_upper = factor
+    if self.height_lower > 1.:
+      raise ValueError('`factor` cannot have abs lower bound larger than 1.0, '
+                       'got {}'.format(factor))
+    self.interpolation = interpolation
+    self._interpolation_method = get_interpolation(interpolation)
+    self.input_spec = InputSpec(ndim=4)
+    self.seed = seed
+    self._rng = make_generator(self.seed)
+    super(RandomHeight, self).__init__(**kwargs)
+
+  def call(self, inputs, training=None):
+    if training is None:
+      training = K.learning_phase()
+
+    def random_height_inputs():
+      """Inputs height-adjusted with random ops."""
+      inputs_shape = array_ops.shape(inputs)
+      h_axis, w_axis = 1, 2
+      img_hd = math_ops.cast(inputs_shape[h_axis], dtypes.float32)
+      img_wd = inputs_shape[w_axis]
+      height_factor = self._rng.uniform(
+          shape=[],
+          minval=(1.0 - self.height_lower),
+          maxval=(1.0 + self.height_upper))
+      adjusted_height = math_ops.cast(height_factor * img_hd, dtypes.int32)
+      adjusted_size = array_ops.stack([adjusted_height, img_wd])
+      output = image_ops.resize_images_v2(
+          images=inputs, size=adjusted_size, method=self._interpolation_method)
+      original_shape = inputs.shape.as_list()
+      output_shape = [original_shape[0]] + [None] + original_shape[2:4]
+      output.set_shape(output_shape)
+      return output
+
+    return tf_utils.smart_cond(training, random_height_inputs, lambda: inputs)
+
+  def compute_output_shape(self, input_shape):
+    input_shape = tensor_shape.TensorShape(input_shape).as_list()
+    return tensor_shape.TensorShape(
+        [input_shape[0], None, input_shape[2], input_shape[3]])
+
+  def get_config(self):
+    config = {
+        'factor': self.factor,
+        'interpolation': self.interpolation,
+        'seed': self.seed,
+    }
+    base_config = super(RandomHeight, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+
+class RandomWidth(Layer):
+  """Randomly vary the width of a batch of images during training.
+
+  Adjusts the width of a batch of images by a random factor. The input
+  should be a 4-D tensor in the "channels_last" image data format.
+
+  By default, this layer is inactive during inference.
+
+  Arguments:
+    factor: A positive float (fraction of original width), or a tuple of
+      size 2 representing lower and upper bound for resizing horizontally. When
+      represented as a single float, this value is used for both the upper and
+      lower bound. For instance, `factor=(0.2, 0.3)` results in an output width
+      varying in the range `[original + 20%, original + 30%]`. `factor=(-0.2,
+      0.3)` results in an output width varying in the range `[original - 20%,
+      original + 30%]`. `factor=0.2` results in an output width varying in the
+      range `[original - 20%, original + 20%]`.
+    interpolation: String, the interpolation method. Defaults to `bilinear`.
+      Supports `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
+      `gaussian`, `mitchellcubic`
+    seed: Integer. Used to create a random seed.
+
+  Input shape:
+    4D tensor with shape:
+    `(samples, height, width, channels)` (data_format='channels_last').
+
+  Output shape:
+    4D tensor with shape:
+    `(samples, random_height, width, channels)`.
+  """
+
+  def __init__(self, factor, interpolation='bilinear', seed=None, **kwargs):
+    self.factor = factor
+    if isinstance(factor, (tuple, list)):
+      self.width_lower = -factor[0]
+      self.width_upper = factor[1]
+    else:
+      self.width_lower = self.width_upper = factor
+    if self.width_lower > 1.:
+      raise ValueError('`factor` cannot have abs lower bound larger than 1.0, '
+                       'got {}'.format(factor))
+    self.interpolation = interpolation
+    self._interpolation_method = get_interpolation(interpolation)
+    self.input_spec = InputSpec(ndim=4)
+    self.seed = seed
+    self._rng = make_generator(self.seed)
+    super(RandomWidth, self).__init__(**kwargs)
+
+  def call(self, inputs, training=None):
+    if training is None:
+      training = K.learning_phase()
+
+    def random_width_inputs():
+      """Inputs width-adjusted with random ops."""
+      inputs_shape = array_ops.shape(inputs)
+      h_axis, w_axis = 1, 2
+      img_hd = inputs_shape[h_axis]
+      img_wd = math_ops.cast(inputs_shape[w_axis], dtypes.float32)
+      width_factor = self._rng.uniform(
+          shape=[],
+          minval=(1.0 - self.width_lower),
+          maxval=(1.0 + self.width_upper))
+      adjusted_width = math_ops.cast(width_factor * img_wd, dtypes.int32)
+      adjusted_size = array_ops.stack([img_hd, adjusted_width])
+      output = image_ops.resize_images_v2(
+          images=inputs, size=adjusted_size, method=self._interpolation_method)
+      original_shape = inputs.shape.as_list()
+      output_shape = original_shape[0:2] + [None] + [original_shape[3]]
+      output.set_shape(output_shape)
+      return output
+
+    return tf_utils.smart_cond(training, random_width_inputs, lambda: inputs)
+
+  def compute_output_shape(self, input_shape):
+    input_shape = tensor_shape.TensorShape(input_shape).as_list()
+    return tensor_shape.TensorShape(
+        [input_shape[0], input_shape[1], None, input_shape[3]])
+
+  def get_config(self):
+    config = {
+        'factor': self.factor,
+        'interpolation': self.interpolation,
+        'seed': self.seed,
+    }
+    base_config = super(RandomWidth, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+
 def make_generator(seed=None):
   if seed:
     return stateful_random_ops.Generator.from_seed(seed)
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
index c3ba19c..861e9fa 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
@@ -607,5 +607,115 @@
     self.assertEqual(layer_1.name, layer.name)
 
 
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class RandomHeightTest(keras_parameterized.TestCase):
+
+  def _run_test(self, factor):
+    np.random.seed(1337)
+    num_samples = 2
+    orig_height = 5
+    orig_width = 8
+    channels = 3
+    with tf_test_util.use_gpu():
+      img = np.random.random((num_samples, orig_height, orig_width, channels))
+      layer = image_preprocessing.RandomHeight(factor)
+      img_out = layer(img, training=True)
+      self.assertEqual(img_out.shape[0], 2)
+      self.assertEqual(img_out.shape[2], 8)
+      self.assertEqual(img_out.shape[3], 3)
+
+  @parameterized.named_parameters(('random_height_4_by_6', (.4, .6)),
+                                  ('random_height_3_by_2', (.3, 1.2)),
+                                  ('random_height_3', .3))
+  def test_random_height_basic(self, factor):
+    self._run_test(factor)
+
+  def test_valid_random_height(self):
+    # need (maxval - minval) * rnd + minval = 0.6
+    mock_factor = 0
+    with test.mock.patch.object(
+        gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
+      with tf_test_util.use_gpu():
+        img = np.random.random((12, 5, 8, 3))
+        layer = image_preprocessing.RandomHeight(.4)
+        img_out = layer(img, training=True)
+        self.assertEqual(img_out.shape[1], 3)
+
+  def test_random_height_invalid_factor(self):
+    with self.assertRaises(ValueError):
+      image_preprocessing.RandomHeight((-1.5, .4))
+
+  def test_random_height_inference(self):
+    with CustomObjectScope({'RandomHeight': image_preprocessing.RandomHeight}):
+      input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
+      expected_output = input_images
+      with tf_test_util.use_gpu():
+        layer = image_preprocessing.RandomHeight(.5)
+        actual_output = layer(input_images, training=0)
+        self.assertAllClose(expected_output, actual_output)
+
+  @tf_test_util.run_v2_only
+  def test_config_with_custom_name(self):
+    layer = image_preprocessing.RandomHeight(.5, name='image_preproc')
+    config = layer.get_config()
+    layer_1 = image_preprocessing.RandomHeight.from_config(config)
+    self.assertEqual(layer_1.name, layer.name)
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class RandomWidthTest(keras_parameterized.TestCase):
+
+  def _run_test(self, factor):
+    np.random.seed(1337)
+    num_samples = 2
+    orig_height = 5
+    orig_width = 8
+    channels = 3
+    with tf_test_util.use_gpu():
+      img = np.random.random((num_samples, orig_height, orig_width, channels))
+      layer = image_preprocessing.RandomWidth(factor)
+      img_out = layer(img, training=True)
+      self.assertEqual(img_out.shape[0], 2)
+      self.assertEqual(img_out.shape[1], 5)
+      self.assertEqual(img_out.shape[3], 3)
+
+  @parameterized.named_parameters(('random_width_4_by_6', (.4, .6)),
+                                  ('random_width_3_by_2', (.3, 1.2)),
+                                  ('random_width_3', .3))
+  def test_random_width_basic(self, factor):
+    self._run_test(factor)
+
+  def test_valid_random_width(self):
+    # need (maxval - minval) * rnd + minval = 0.6
+    mock_factor = 0
+    with test.mock.patch.object(
+        gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
+      with tf_test_util.use_gpu():
+        img = np.random.random((12, 8, 5, 3))
+        layer = image_preprocessing.RandomWidth(.4)
+        img_out = layer(img, training=True)
+        self.assertEqual(img_out.shape[2], 3)
+
+  def test_random_width_invalid_factor(self):
+    with self.assertRaises(ValueError):
+      image_preprocessing.RandomWidth((-1.5, .4))
+
+  def test_random_width_inference(self):
+    with CustomObjectScope({'RandomWidth': image_preprocessing.RandomWidth}):
+      input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
+      expected_output = input_images
+      with tf_test_util.use_gpu():
+        layer = image_preprocessing.RandomWidth(.5)
+        actual_output = layer(input_images, training=0)
+        self.assertAllClose(expected_output, actual_output)
+
+  @tf_test_util.run_v2_only
+  def test_config_with_custom_name(self):
+    layer = image_preprocessing.RandomWidth(.5, name='image_preproc')
+    config = layer.get_config()
+    layer_1 = image_preprocessing.RandomWidth.from_config(config)
+    self.assertEqual(layer_1.name, layer.name)
+
+
 if __name__ == '__main__':
   test.main()