blob: 4807f23188ad757af94d902f0cb65ec6e61a17cf [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution tests for keras.layers.preprocessing.category_crossing."""
import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import backend
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import category_crossing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
def batch_wrapper(dataset, batch_size, distribution, repeat=None):
if repeat:
dataset = dataset.repeat(repeat)
# TPUs currently require fully defined input shapes, drop_remainder ensures
# the input will have fully defined shapes.
if backend.is_tpu_strategy(distribution):
return dataset.batch(batch_size, drop_remainder=True)
else:
return dataset.batch(batch_size)
@ds_combinations.generate(
combinations.combine(
# Investigate why crossing is not supported with TPU.
distribution=all_strategies,
mode=['eager', 'graph']))
class CategoryCrossingDistributionTest(
keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
def test_distribution(self, distribution):
input_array_1 = np.array([['a', 'b'], ['c', 'd']])
input_array_2 = np.array([['e', 'f'], ['g', 'h']])
inp_dataset = dataset_ops.DatasetV2.from_tensor_slices(
{'input_1': input_array_1, 'input_2': input_array_2})
inp_dataset = batch_wrapper(inp_dataset, 2, distribution)
# pyformat: disable
expected_output = [[b'a_X_e', b'a_X_f', b'b_X_e', b'b_X_f'],
[b'c_X_g', b'c_X_h', b'd_X_g', b'd_X_h']]
config.set_soft_device_placement(True)
with distribution.scope():
input_data_1 = keras.Input(shape=(2,), dtype=dtypes.string,
name='input_1')
input_data_2 = keras.Input(shape=(2,), dtype=dtypes.string,
name='input_2')
input_data = [input_data_1, input_data_2]
layer = category_crossing.CategoryCrossing()
int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(inp_dataset)
self.assertAllEqual(expected_output, output_dataset)
if __name__ == '__main__':
test.main()