in ForwardAcc, add test
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 40cfa87..86c9197 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -177,12 +177,12 @@
 template <typename Gradient>
 class ForwardFunction
     : public std::function<Status(const std::vector<Gradient*>&,
-                                  std::vector<Gradient*>*)> {
+                                  std::vector<Gradient*>*, bool)> {
  public:
   template <typename lambda_type>
   explicit ForwardFunction(lambda_type lambda)
       : std::function<Status(const std::vector<Gradient*>&,
-                             std::vector<Gradient*>*)>(lambda) {}
+                             std::vector<Gradient*>*, bool)>(lambda) {}
 };
 
 // Computes Jacobian-vector products using forward-mode automatic
@@ -205,8 +205,9 @@
   // Does not take ownership of `vspace`, which must outlive the
   // ForwardAccumulator.
   explicit ForwardAccumulator(
-      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
-      : vspace_(vspace) {
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+      bool use_batch)
+      : vspace_(vspace), use_batch(use_batch) {
     call_state_.emplace(nullptr, false);
   }
 
@@ -314,6 +315,9 @@
   // available in language bindings (e.g. Python).
   const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
 
+  //Decides if tangents are vector rised or not
+  bool use_batch;
+
   struct AccumulatorCallState {
     AccumulatorCallState(
         GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
@@ -1062,7 +1066,7 @@
         output_tensors, backward_function_getter, backward_function_deleter,
         in_grads, &forward_grads));
   } else {
-    TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
+    TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads, use_batch));
   }
   for (int i = 0; i < forward_grads.size(); ++i) {
     if (forward_grads[i] != nullptr) {
diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py
index cd91295..bbe97f5 100644
--- a/tensorflow/python/eager/forwardprop.py
+++ b/tensorflow/python/eager/forwardprop.py
@@ -326,7 +326,7 @@
   <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
   """
 
-  def __init__(self, primals, tangents):
+  def __init__(self, primals, tangents, use_batch=False):
     """Specify tensors to watch and their Jacobian-vector products.
 
     Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
@@ -348,7 +348,7 @@
       ValueError: If the same tensor or variable is specified multiple times in
         `primals`.
     """
-    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
+    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(use_batch)
     self._recording = False
     primal_ids = set()
     for primal in nest.flatten(primals):
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index ad55a53..b72dac1 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -1,7 +1,4 @@
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
@@ -1009,7 +1006,7 @@
     self.assertAllClose(hess_value, hessian_pfor)
 
 
-class JacobianTests(test.TestCase, parameterized.TestCase):
+class BatchTests(test.TestCase, parameterized.TestCase):
 
   @parameterized.parameters([(math_ops.sin, (2, 3), 5),
                              (math_ops.sin, (2, 3, 4), 10)])
@@ -1020,6 +1017,19 @@
         _jvp_batch(f, primals, tangent_batch)[1],
         _jvp_batch_matmul(f, primals, *tangent_batch))
 
+  def testBatchCorrectness(self):
+    x = constant_op.constant(2.0)
+    y = constant_op.constant(5.0)
+    tangents = (
+      constant_op.constant([1., 0., 1.]),
+      constant_op.constant([0., 1., 1.]),
+    )
+    with forwardprop.ForwardAccumulator((x, y), tangents, True) as acc:
+      z = x * y
+    self.assertAllClose(
+      acc.jvp(z),
+      constant_op.constant([5.0, 2.0, 7.0]
+    ))
 
 if __name__ == "__main__":
   # TODO(allenl): Also test with 1.x-style graph mode.
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a5c9c18..4431502 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -284,7 +284,7 @@
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
 
 // Creates a new forward accumulator. Does not add it to the active set.
-PyObject* TFE_Py_ForwardAccumulatorNew();
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch);
 
 // Adds a ForwardAccumulator to the active set, meaning it will watch executed
 // operations. It must not already be in the active set.
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index dcaaafe..11bd06e 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -2419,7 +2419,8 @@
 tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
                                    PyObject* inputs, PyObject* results,
                                    const std::vector<PyObject*>& input_tangents,
-                                   std::vector<PyObject*>* output_tangents) {
+                                   std::vector<PyObject*>* output_tangents,
+                                   bool use_batch) {
   if (forward_gradient_function == nullptr) {
     return tensorflow::errors::Internal(
         "No forward gradient function registered.");
@@ -2430,9 +2431,10 @@
   // Normalize the input sequence to a tuple so it works with function
   // caching; otherwise it may be an opaque _InputList object.
   tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
+  PyObject* to_batch = (use_batch) ? Py_True : Py_False;
   tensorflow::Safe_PyObjectPtr callback_args(
-      Py_BuildValue("OOOOO", op_name, attrs, input_tuple.get(), results,
-                    py_input_tangents.get()));
+      Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
+                    py_input_tangents.get(), to_batch));
   tensorflow::Safe_PyObjectPtr py_result(
       PyObject_CallObject(forward_gradient_function, callback_args.get()));
   if (py_result == nullptr || PyErr_Occurred()) {
@@ -2555,7 +2557,7 @@
   } else {
     tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
         [forward_function](const std::vector<PyObject*>& input_tangents,
-                           std::vector<PyObject*>* output_tangents) {
+                           std::vector<PyObject*>* output_tangents, bool use_batch=false) {
           return CallOpSpecificJVPFunction(forward_function, input_tangents,
                                            output_tangents);
         });
@@ -2797,7 +2799,7 @@
   return PyList_New(0);
 }
 
-PyObject* TFE_Py_ForwardAccumulatorNew() {
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
   TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
   TFE_Py_ForwardAccumulator* accumulator =
@@ -2808,7 +2810,7 @@
             "ForwardAccumulator requires a PyVSpace to be registered."),
         nullptr);
   }
-  accumulator->accumulator = new ForwardAccumulator(*py_vspace);
+  accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
   return reinterpret_cast<PyObject*>(accumulator);
 }
 
@@ -3166,9 +3168,9 @@
   tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
       [op_name, attrs, inputs, results](
           const std::vector<PyObject*>& input_tangents,
-          std::vector<PyObject*>* output_tangents) {
+          std::vector<PyObject*>* output_tangents, bool use_batch) {
         return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
-                               output_tangents);
+                               output_tangents, use_batch);
       });
   tensorflow::eager::ForwardFunction<PyObject>* forward_function;
   if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index bf11faa..5572cfb 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -730,8 +730,8 @@
         });
 
   // TFE_Py_ForwardAccumulator logic.
-  m.def("TFE_Py_ForwardAccumulatorNew", []() {
-    return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
+  m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
+    return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
   });
 
   m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {