Add RandomHeight and RandomWidth image preproc layer.
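
Example usage, constructing the layers directly from image_preprocessing.py
(a sketch mirroring the unit tests; batch shapes are illustrative):

    import numpy as np
    from tensorflow.python.keras.layers.preprocessing import image_preprocessing

    images = np.random.random((2, 5, 8, 3)).astype(np.float32)  # NHWC batch

    # Height varies in [original - 20%, original + 30%]; width is unchanged.
    height_layer = image_preprocessing.RandomHeight(factor=(-0.2, 0.3), seed=1337)
    # Width varies in [original - 40%, original + 40%]; height is unchanged.
    width_layer = image_preprocessing.RandomWidth(0.4, interpolation='nearest')

    out = height_layer(images, training=True)  # static shape (2, None, 8, 3)
    out = width_layer(images, training=True)   # static shape (2, 5, None, 3)
    # With training=False (inference), the inputs pass through unchanged.
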
PiperOrigin-RevId: 293494681
Change-Id: Ia60494b8fddba8c271062987d8d3ca1064a5ceeb
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
index 8148f95..0cc56ba 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
@@ -960,6 +960,182 @@
return dict(list(base_config.items()) + list(config.items()))
+class RandomHeight(Layer):
+ """Randomly vary the height of a batch of images during training.
+
+ Adjusts the height of a batch of images by a random factor. The input
+ should be a 4-D tensor in the "channels_last" image data format.
+
+ By default, this layer is inactive during inference.
+
+ Arguments:
+    factor: A positive float (fraction of original height), or a tuple of
+      size 2 representing the lower and upper bounds for resizing vertically.
+      When represented as a single float, this value is used for both the
+      lower and upper bounds. For instance, `factor=(0.2, 0.3)` results in an
+      output height varying in the range `[original + 20%, original + 30%]`.
+      `factor=(-0.2, 0.3)` results in an output height varying in the range
+      `[original - 20%, original + 30%]`. `factor=0.2` results in an output
+      height varying in the range `[original - 20%, original + 20%]`.
+    interpolation: String, the interpolation method. Defaults to `bilinear`.
+      Supports `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`,
+      `lanczos5`, `gaussian`, `mitchellcubic`.
+ seed: Integer. Used to create a random seed.
+
+ Input shape:
+ 4D tensor with shape:
+ `(samples, height, width, channels)` (data_format='channels_last').
+
+ Output shape:
+ 4D tensor with shape:
+ `(samples, random_height, width, channels)`.
+ """
+
+ def __init__(self, factor, interpolation='bilinear', seed=None, **kwargs):
+ self.factor = factor
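+    # For a tuple, the lower bound is stored negated so that the sampling
+    # range in `call` is simply [1.0 - height_lower, 1.0 + height_upper).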
+ if isinstance(factor, (tuple, list)):
+ self.height_lower = -factor[0]
+ self.height_upper = factor[1]
+ else:
+ self.height_lower = self.height_upper = factor
+ if self.height_lower > 1.:
+      raise ValueError('`factor` cannot have an absolute lower bound larger '
+                       'than 1.0, got {}'.format(factor))
+ self.interpolation = interpolation
+ self._interpolation_method = get_interpolation(interpolation)
+ self.input_spec = InputSpec(ndim=4)
+ self.seed = seed
+ self._rng = make_generator(self.seed)
+ super(RandomHeight, self).__init__(**kwargs)
+
+ def call(self, inputs, training=None):
+ if training is None:
+ training = K.learning_phase()
+
+ def random_height_inputs():
+ """Inputs height-adjusted with random ops."""
+ inputs_shape = array_ops.shape(inputs)
+ h_axis, w_axis = 1, 2
+ img_hd = math_ops.cast(inputs_shape[h_axis], dtypes.float32)
+ img_wd = inputs_shape[w_axis]
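+      # Scale the original height by a factor drawn uniformly from
+      # [1.0 - height_lower, 1.0 + height_upper); the width stays unchanged.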
+ height_factor = self._rng.uniform(
+ shape=[],
+ minval=(1.0 - self.height_lower),
+ maxval=(1.0 + self.height_upper))
+ adjusted_height = math_ops.cast(height_factor * img_hd, dtypes.int32)
+ adjusted_size = array_ops.stack([adjusted_height, img_wd])
+ output = image_ops.resize_images_v2(
+ images=inputs, size=adjusted_size, method=self._interpolation_method)
+ original_shape = inputs.shape.as_list()
+ output_shape = [original_shape[0]] + [None] + original_shape[2:4]
+ output.set_shape(output_shape)
+ return output
+
+ return tf_utils.smart_cond(training, random_height_inputs, lambda: inputs)
+
+ def compute_output_shape(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
+ return tensor_shape.TensorShape(
+ [input_shape[0], None, input_shape[2], input_shape[3]])
+
+ def get_config(self):
+ config = {
+ 'factor': self.factor,
+ 'interpolation': self.interpolation,
+ 'seed': self.seed,
+ }
+ base_config = super(RandomHeight, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
+class RandomWidth(Layer):
+ """Randomly vary the width of a batch of images during training.
+
+ Adjusts the width of a batch of images by a random factor. The input
+ should be a 4-D tensor in the "channels_last" image data format.
+
+ By default, this layer is inactive during inference.
+
+ Arguments:
+    factor: A positive float (fraction of original width), or a tuple of
+      size 2 representing the lower and upper bounds for resizing
+      horizontally. When represented as a single float, this value is used
+      for both the lower and upper bounds. For instance, `factor=(0.2, 0.3)`
+      results in an output width varying in the range `[original + 20%,
+      original + 30%]`. `factor=(-0.2, 0.3)` results in an output width
+      varying in the range `[original - 20%, original + 30%]`. `factor=0.2`
+      results in an output width varying in the range `[original - 20%,
+      original + 20%]`.
+    interpolation: String, the interpolation method. Defaults to `bilinear`.
+      Supports `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`,
+      `lanczos5`, `gaussian`, `mitchellcubic`.
+ seed: Integer. Used to create a random seed.
+
+ Input shape:
+ 4D tensor with shape:
+ `(samples, height, width, channels)` (data_format='channels_last').
+
+ Output shape:
+ 4D tensor with shape:
+    `(samples, height, random_width, channels)`.
+ """
+
+ def __init__(self, factor, interpolation='bilinear', seed=None, **kwargs):
+ self.factor = factor
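+    # For a tuple, the lower bound is stored negated so that the sampling
+    # range in `call` is simply [1.0 - width_lower, 1.0 + width_upper).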
+ if isinstance(factor, (tuple, list)):
+ self.width_lower = -factor[0]
+ self.width_upper = factor[1]
+ else:
+ self.width_lower = self.width_upper = factor
+ if self.width_lower > 1.:
+      raise ValueError('`factor` cannot have an absolute lower bound larger '
+                       'than 1.0, got {}'.format(factor))
+ self.interpolation = interpolation
+ self._interpolation_method = get_interpolation(interpolation)
+ self.input_spec = InputSpec(ndim=4)
+ self.seed = seed
+ self._rng = make_generator(self.seed)
+ super(RandomWidth, self).__init__(**kwargs)
+
+ def call(self, inputs, training=None):
+ if training is None:
+ training = K.learning_phase()
+
+ def random_width_inputs():
+ """Inputs width-adjusted with random ops."""
+ inputs_shape = array_ops.shape(inputs)
+ h_axis, w_axis = 1, 2
+ img_hd = inputs_shape[h_axis]
+ img_wd = math_ops.cast(inputs_shape[w_axis], dtypes.float32)
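+      # Scale the original width by a factor drawn uniformly from
+      # [1.0 - width_lower, 1.0 + width_upper); the height stays unchanged.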
+ width_factor = self._rng.uniform(
+ shape=[],
+ minval=(1.0 - self.width_lower),
+ maxval=(1.0 + self.width_upper))
+ adjusted_width = math_ops.cast(width_factor * img_wd, dtypes.int32)
+ adjusted_size = array_ops.stack([img_hd, adjusted_width])
+ output = image_ops.resize_images_v2(
+ images=inputs, size=adjusted_size, method=self._interpolation_method)
+ original_shape = inputs.shape.as_list()
+ output_shape = original_shape[0:2] + [None] + [original_shape[3]]
+ output.set_shape(output_shape)
+ return output
+
+ return tf_utils.smart_cond(training, random_width_inputs, lambda: inputs)
+
+ def compute_output_shape(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
+ return tensor_shape.TensorShape(
+ [input_shape[0], input_shape[1], None, input_shape[3]])
+
+ def get_config(self):
+ config = {
+ 'factor': self.factor,
+ 'interpolation': self.interpolation,
+ 'seed': self.seed,
+ }
+ base_config = super(RandomWidth, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
def make_generator(seed=None):
if seed:
return stateful_random_ops.Generator.from_seed(seed)
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
index c3ba19c..861e9fa 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
@@ -607,5 +607,115 @@
self.assertEqual(layer_1.name, layer.name)
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class RandomHeightTest(keras_parameterized.TestCase):
+
+ def _run_test(self, factor):
+ np.random.seed(1337)
+ num_samples = 2
+ orig_height = 5
+ orig_width = 8
+ channels = 3
+ with tf_test_util.use_gpu():
+ img = np.random.random((num_samples, orig_height, orig_width, channels))
+ layer = image_preprocessing.RandomHeight(factor)
+ img_out = layer(img, training=True)
+ self.assertEqual(img_out.shape[0], 2)
+ self.assertEqual(img_out.shape[2], 8)
+ self.assertEqual(img_out.shape[3], 3)
+
+ @parameterized.named_parameters(('random_height_4_by_6', (.4, .6)),
+ ('random_height_3_by_2', (.3, 1.2)),
+ ('random_height_3', .3))
+ def test_random_height_basic(self, factor):
+ self._run_test(factor)
+
+ def test_valid_random_height(self):
+    # The mocked uniform sample is 0, so the resize factor becomes
+    # (maxval - minval) * 0 + minval = 1 - 0.4 = 0.6, i.e. height 5 -> 3.
+ mock_factor = 0
+ with test.mock.patch.object(
+ gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
+ with tf_test_util.use_gpu():
+ img = np.random.random((12, 5, 8, 3))
+ layer = image_preprocessing.RandomHeight(.4)
+ img_out = layer(img, training=True)
+ self.assertEqual(img_out.shape[1], 3)
+
+ def test_random_height_invalid_factor(self):
+ with self.assertRaises(ValueError):
+ image_preprocessing.RandomHeight((-1.5, .4))
+
+ def test_random_height_inference(self):
+ with CustomObjectScope({'RandomHeight': image_preprocessing.RandomHeight}):
+ input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
+ expected_output = input_images
+ with tf_test_util.use_gpu():
+ layer = image_preprocessing.RandomHeight(.5)
+ actual_output = layer(input_images, training=0)
+ self.assertAllClose(expected_output, actual_output)
+
+ @tf_test_util.run_v2_only
+ def test_config_with_custom_name(self):
+ layer = image_preprocessing.RandomHeight(.5, name='image_preproc')
+ config = layer.get_config()
+ layer_1 = image_preprocessing.RandomHeight.from_config(config)
+ self.assertEqual(layer_1.name, layer.name)
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class RandomWidthTest(keras_parameterized.TestCase):
+
+ def _run_test(self, factor):
+ np.random.seed(1337)
+ num_samples = 2
+ orig_height = 5
+ orig_width = 8
+ channels = 3
+ with tf_test_util.use_gpu():
+ img = np.random.random((num_samples, orig_height, orig_width, channels))
+ layer = image_preprocessing.RandomWidth(factor)
+ img_out = layer(img, training=True)
+ self.assertEqual(img_out.shape[0], 2)
+ self.assertEqual(img_out.shape[1], 5)
+ self.assertEqual(img_out.shape[3], 3)
+
+ @parameterized.named_parameters(('random_width_4_by_6', (.4, .6)),
+ ('random_width_3_by_2', (.3, 1.2)),
+ ('random_width_3', .3))
+ def test_random_width_basic(self, factor):
+ self._run_test(factor)
+
+ def test_valid_random_width(self):
+    # The mocked uniform sample is 0, so the resize factor becomes
+    # (maxval - minval) * 0 + minval = 1 - 0.4 = 0.6, i.e. width 5 -> 3.
+ mock_factor = 0
+ with test.mock.patch.object(
+ gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
+ with tf_test_util.use_gpu():
+ img = np.random.random((12, 8, 5, 3))
+ layer = image_preprocessing.RandomWidth(.4)
+ img_out = layer(img, training=True)
+ self.assertEqual(img_out.shape[2], 3)
+
+ def test_random_width_invalid_factor(self):
+ with self.assertRaises(ValueError):
+ image_preprocessing.RandomWidth((-1.5, .4))
+
+ def test_random_width_inference(self):
+ with CustomObjectScope({'RandomWidth': image_preprocessing.RandomWidth}):
+ input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
+ expected_output = input_images
+ with tf_test_util.use_gpu():
+ layer = image_preprocessing.RandomWidth(.5)
+ actual_output = layer(input_images, training=0)
+ self.assertAllClose(expected_output, actual_output)
+
+ @tf_test_util.run_v2_only
+ def test_config_with_custom_name(self):
+ layer = image_preprocessing.RandomWidth(.5, name='image_preproc')
+ config = layer.get_config()
+ layer_1 = image_preprocessing.RandomWidth.from_config(config)
+ self.assertEqual(layer_1.name, layer.name)
+
+
if __name__ == '__main__':
test.main()