Add support for class_weight to DataHandler.

PiperOrigin-RevId: 286450504
Change-Id: Ib8bfeab71edc41ba99bdc8b78c104fad0eb4904a
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 13fb866..6b65726 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -1148,7 +1148,6 @@
 
   # TODO(omalleyt): Handle `validation_split` with separate utility.
   # TODO(omalleyt): Handle `validation_data` batch size when `x` is a gen.
-  # TODO(omalleyt): Handle `class_weight` in `DataAdapter`s.
   def __init__(self,
                x,
                y=None,
@@ -1158,6 +1157,7 @@
                initial_epoch=0,
                epochs=1,
                shuffle=False,
+               class_weight=None,
                max_queue_size=10,
                workers=1,
                use_multiprocessing=False):
@@ -1182,6 +1182,8 @@
 
     strategy = ds_context.get_strategy()
     dataset = self._train_adapter.get_dataset()
+    if class_weight:
+      dataset = dataset.map(_make_class_weight_map_fn(class_weight))
     self._train_dataset = strategy.experimental_distribute_dataset(dataset)
     self._steps_per_epoch = self._infer_steps(steps_per_epoch)
 
@@ -1252,3 +1254,55 @@
     if size >= 0:
       return size
     return None
+
+
+def _make_class_weight_map_fn(class_weight):
+  """Builds a `tf.data.Dataset.map` function that applies class weighting.
+
+  The `Dataset` is assumed to be in the format `(x, y)` or `(x, y, sw)`, where
+  `y` must be a single `Tensor`.
+
+  Arguments:
+    class_weight: A map where the keys are integer class ids and values are
+      the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`
+
+  Returns:
+    A function that can be used with `tf.data.Dataset.map` to apply class
+    weighting.
+  """
+  class_ids = list(sorted(class_weight.keys()))
+  expected_class_ids = list(range(len(class_ids)))
+  if class_ids != expected_class_ids:
+    error_msg = (
+        "Expected `class_weight` to be a dict with keys from 0 to one less "
+        "than the number of classes, found {}").format(class_weight)
+    raise ValueError(error_msg)
+
+  class_weight_tensor = ops.convert_to_tensor(
+      [class_weight[c] for c in class_ids])
+
+  def _class_weights_map_fn(*data):
+    """Convert `class_weight` to `sample_weight`."""
+    if len(data) == 2:
+      x, y = data
+      sw = None
+    else:
+      x, y, sw = data
+
+    if nest.is_sequence(y):
+      raise ValueError(
+          "`class_weight` is only supported for `Model`s with a single output.")
+
+    cw = array_ops.gather_v2(class_weight_tensor, y)
+    if sw is not None:
+      cw = math_ops.cast(cw, sw.dtype)
+      if len(cw.shape.as_list()) > len(sw.shape.as_list()):
+        cw = array_ops.squeeze(cw)
+      # `class_weight` and `sample_weight` are multiplicative.
+      sw = sw * cw
+    else:
+      sw = cw
+
+    return x, y, sw
+
+  return _class_weights_map_fn
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index 5b0f119..8ada2f5 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -963,6 +963,77 @@
     self.assertEqual(returned_data, [[([0],), ([1],),
                                       ([2],)], [([0],), ([1],), ([2],)]])
 
+  def test_class_weight(self):
+    data_handler = data_adapter.DataHandler(
+        x=[[0], [1], [2]],
+        y=[[2], [1], [0]],
+        class_weight={
+            0: 0.5,
+            1: 1.,
+            2: 1.5
+        },
+        epochs=2,
+        steps_per_epoch=3)
+    returned_data = []
+    for _, iterator in data_handler.enumerate_epochs():
+      epoch_data = []
+      for _ in data_handler.steps():
+        epoch_data.append(next(iterator))
+      returned_data.append(epoch_data)
+    returned_data = self.evaluate(returned_data)
+    self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [1.]),
+                                      ([2], [0], [0.5])],
+                                     [([0], [2], [1.5]), ([1], [1], [1.]),
+                                      ([2], [0], [0.5])]])
+
+  def test_class_weight_and_sample_weight(self):
+    data_handler = data_adapter.DataHandler(
+        x=[[0], [1], [2]],
+        y=[[2], [1], [0]],
+        sample_weight=[[1.], [2.], [4.]],
+        class_weight={
+            0: 0.5,
+            1: 1.,
+            2: 1.5
+        },
+        epochs=2,
+        steps_per_epoch=3)
+    returned_data = []
+    for _, iterator in data_handler.enumerate_epochs():
+      epoch_data = []
+      for _ in data_handler.steps():
+        epoch_data.append(next(iterator))
+      returned_data.append(epoch_data)
+    returned_data = self.evaluate(returned_data)
+    self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [2.]),
+                                      ([2], [0], [2.])],
+                                     [([0], [2], [1.5]), ([1], [1], [2.]),
+                                      ([2], [0], [2.])]])
+
+  def test_class_weight_user_errors(self):
+    with self.assertRaisesRegexp(ValueError, 'to be a dict with keys'):
+      data_adapter.DataHandler(
+          x=[[0], [1], [2]],
+          y=[[2], [1], [0]],
+          batch_size=1,
+          sample_weight=[[1.], [2.], [4.]],
+          class_weight={
+              0: 0.5,
+              1: 1.,
+              3: 1.5  # Skips class `2`.
+          })
+
+    with self.assertRaisesRegexp(ValueError, 'with a single output'):
+      data_adapter.DataHandler(
+          x=np.ones((10, 1)),
+          y=[np.ones((10, 1)), np.zeros((10, 1))],
+          batch_size=2,
+          class_weight={
+              0: 0.5,
+              1: 1.,
+              2: 1.5
+          })
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()