Add support for class_weight to DataHandler.
PiperOrigin-RevId: 286450504
Change-Id: Ib8bfeab71edc41ba99bdc8b78c104fad0eb4904a
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 13fb866..6b65726 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -1148,7 +1148,6 @@
# TODO(omalleyt): Handle `validation_split` with separate utility.
# TODO(omalleyt): Handle `validation_data` batch size when `x` is a gen.
- # TODO(omalleyt): Handle `class_weight` in `DataAdapter`s.
def __init__(self,
x,
y=None,
@@ -1158,6 +1157,7 @@
initial_epoch=0,
epochs=1,
shuffle=False,
+ class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False):
@@ -1182,6 +1182,8 @@
strategy = ds_context.get_strategy()
dataset = self._train_adapter.get_dataset()
+ if class_weight:
+ dataset = dataset.map(_make_class_weight_map_fn(class_weight))
self._train_dataset = strategy.experimental_distribute_dataset(dataset)
self._steps_per_epoch = self._infer_steps(steps_per_epoch)
@@ -1252,3 +1254,55 @@
if size >= 0:
return size
return None
+
+
+def _make_class_weight_map_fn(class_weight):
+ """Applies class weighting to a `Dataset`.
+
+ The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
+ `y` must be a single `Tensor`.
+
+ Arguments:
+ class_weight: A map where the keys are integer class ids and values are
+ the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`
+
+ Returns:
+ A function that can be used with `tf.data.Dataset.map` to apply class
+ weighting.
+ """
+ class_ids = list(sorted(class_weight.keys()))
+ expected_class_ids = list(range(len(class_ids)))
+ if class_ids != expected_class_ids:
+ error_msg = (
+ "Expected `class_weight` to be a dict with keys from 0 to one less "
+ "than the number of classes, found {}").format(class_weight)
+ raise ValueError(error_msg)
+
+ class_weight_tensor = ops.convert_to_tensor(
+ [class_weight[c] for c in class_ids])
+
+ def _class_weights_map_fn(*data):
+ """Convert `class_weight` to `sample_weight`."""
+ if len(data) == 2:
+ x, y = data
+ sw = None
+ else:
+ x, y, sw = data
+
+ if nest.is_sequence(y):
+ raise ValueError(
+ "`class_weight` is only supported for `Model`s with a single output.")
+
+ cw = array_ops.gather_v2(class_weight_tensor, y)
+ if sw is not None:
+ cw = math_ops.cast(cw, sw.dtype)
+ if len(cw.shape.as_list()) > len(sw.shape.as_list()):
+ cw = array_ops.squeeze(cw)
+ # `class_weight` and `sample_weight` are multiplicative.
+ sw = sw * cw
+ else:
+ sw = cw
+
+ return x, y, sw
+
+ return _class_weights_map_fn
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index 5b0f119..8ada2f5 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -963,6 +963,77 @@
self.assertEqual(returned_data, [[([0],), ([1],),
([2],)], [([0],), ([1],), ([2],)]])
+ def test_class_weight(self):
+ data_handler = data_adapter.DataHandler(
+ x=[[0], [1], [2]],
+ y=[[2], [1], [0]],
+ class_weight={
+ 0: 0.5,
+ 1: 1.,
+ 2: 1.5
+ },
+ epochs=2,
+ steps_per_epoch=3)
+ returned_data = []
+ for _, iterator in data_handler.enumerate_epochs():
+ epoch_data = []
+ for _ in data_handler.steps():
+ epoch_data.append(next(iterator))
+ returned_data.append(epoch_data)
+ returned_data = self.evaluate(returned_data)
+ self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [1.]),
+ ([2], [0], [0.5])],
+ [([0], [2], [1.5]), ([1], [1], [1.]),
+ ([2], [0], [0.5])]])
+
+ def test_class_weight_and_sample_weight(self):
+ data_handler = data_adapter.DataHandler(
+ x=[[0], [1], [2]],
+ y=[[2], [1], [0]],
+ sample_weight=[[1.], [2.], [4.]],
+ class_weight={
+ 0: 0.5,
+ 1: 1.,
+ 2: 1.5
+ },
+ epochs=2,
+ steps_per_epoch=3)
+ returned_data = []
+ for _, iterator in data_handler.enumerate_epochs():
+ epoch_data = []
+ for _ in data_handler.steps():
+ epoch_data.append(next(iterator))
+ returned_data.append(epoch_data)
+ returned_data = self.evaluate(returned_data)
+ self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [2.]),
+ ([2], [0], [2.])],
+ [([0], [2], [1.5]), ([1], [1], [2.]),
+ ([2], [0], [2.])]])
+
+ def test_class_weight_user_errors(self):
+ with self.assertRaisesRegexp(ValueError, 'to be a dict with keys'):
+ data_adapter.DataHandler(
+ x=[[0], [1], [2]],
+ y=[[2], [1], [0]],
+ batch_size=1,
+ sample_weight=[[1.], [2.], [4.]],
+ class_weight={
+ 0: 0.5,
+ 1: 1.,
+ 3: 1.5 # Skips class `2`.
+ })
+
+ with self.assertRaisesRegexp(ValueError, 'with a single output'):
+ data_adapter.DataHandler(
+ x=np.ones((10, 1)),
+ y=[np.ones((10, 1)), np.zeros((10, 1))],
+ batch_size=2,
+ class_weight={
+ 0: 0.5,
+ 1: 1.,
+ 2: 1.5
+ })
+
if __name__ == '__main__':
ops.enable_eager_execution()