Add use_batch option to ForwardAccumulator for batched tangents; add test
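Threads a new use_batch flag from the Python ForwardAccumulator
constructor through pywrap_tfe and tfe_wrapper down to the C++
tensorflow::eager::ForwardAccumulator and ForwardFunction, so that a
batch of tangent vectors can be pushed through a single accumulator.
Adds a correctness test for the batched JVP.

A minimal usage sketch of the new flag (eager mode assumed; mirrors
the test added below):

    from tensorflow.python.eager import forwardprop
    from tensorflow.python.framework import constant_op

    x = constant_op.constant(2.0)
    y = constant_op.constant(5.0)
    # One tangent per primal; the leading axis indexes the batch.
    tangents = (constant_op.constant([1., 0., 1.]),
                constant_op.constant([0., 1., 1.]))
    with forwardprop.ForwardAccumulator((x, y), tangents,
                                        use_batch=True) as acc:
      z = x * y
    # d(x*y) = y*dx + x*dy per batch entry: [5., 2., 7.]
    print(acc.jvp(z))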
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 40cfa87..86c9197 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -177,12 +177,12 @@
template <typename Gradient>
class ForwardFunction
: public std::function<Status(const std::vector<Gradient*>&,
- std::vector<Gradient*>*)> {
+ std::vector<Gradient*>*, bool)> {
public:
template <typename lambda_type>
explicit ForwardFunction(lambda_type lambda)
: std::function<Status(const std::vector<Gradient*>&,
- std::vector<Gradient*>*)>(lambda) {}
+ std::vector<Gradient*>*, bool)>(lambda) {}
};
// Computes Jacobian-vector products using forward-mode automatic
@@ -205,8 +205,9 @@
// Does not take ownership of `vspace`, which must outlive the
// ForwardAccumulator.
explicit ForwardAccumulator(
- const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
- : vspace_(vspace) {
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ bool use_batch)
+ : vspace_(vspace), use_batch_(use_batch) {
call_state_.emplace(nullptr, false);
}
@@ -314,6 +315,9 @@
// available in language bindings (e.g. Python).
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
+ // Decides whether tangents are vectorized (batched along their first axis).
+ bool use_batch_;
+
struct AccumulatorCallState {
AccumulatorCallState(
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
@@ -1062,7 +1066,8 @@
output_tensors, backward_function_getter, backward_function_deleter,
in_grads, &forward_grads));
} else {
- TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
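+ // Forward the batching flag to the user-provided forward function.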
+ TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads, use_batch_));
}
for (int i = 0; i < forward_grads.size(); ++i) {
if (forward_grads[i] != nullptr) {
diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py
index cd91295..bbe97f5 100644
--- a/tensorflow/python/eager/forwardprop.py
+++ b/tensorflow/python/eager/forwardprop.py
@@ -326,7 +326,7 @@
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
"""
- def __init__(self, primals, tangents):
+ def __init__(self, primals, tangents, use_batch=False):
"""Specify tensors to watch and their Jacobian-vector products.
Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
@@ -348,7 +348,7 @@
ValueError: If the same tensor or variable is specified multiple times in
`primals`.
"""
- self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew()
+ self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(use_batch)
self._recording = False
primal_ids = set()
for primal in nest.flatten(primals):
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index ad55a53..b72dac1 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -1009,7 +1009,7 @@
self.assertAllClose(hess_value, hessian_pfor)
-class JacobianTests(test.TestCase, parameterized.TestCase):
+class BatchTests(test.TestCase, parameterized.TestCase):
@parameterized.parameters([(math_ops.sin, (2, 3), 5),
(math_ops.sin, (2, 3, 4), 10)])
@@ -1020,6 +1020,18 @@
_jvp_batch(f, primals, tangent_batch)[1],
_jvp_batch_matmul(f, primals, *tangent_batch))
+ def testBatchCorrectness(self):
+ x = constant_op.constant(2.0)
+ y = constant_op.constant(5.0)
+ tangents = (
+ constant_op.constant([1., 0., 1.]),
+ constant_op.constant([0., 1., 1.]),
+ )
+ with forwardprop.ForwardAccumulator((x, y), tangents, use_batch=True) as acc:
+ z = x * y
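+ # For z = x * y, the batched JVP is y * dx + x * dy per tangent entry.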
+ self.assertAllClose(acc.jvp(z),
+ constant_op.constant([5.0, 2.0, 7.0]))
if __name__ == "__main__":
# TODO(allenl): Also test with 1.x-style graph mode.
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a5c9c18..4431502 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -284,7 +284,7 @@
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
// Creates a new forward accumulator. Does not add it to the active set.
-PyObject* TFE_Py_ForwardAccumulatorNew();
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch);
// Adds a ForwardAccumulator to the active set, meaning it will watch executed
// operations. It must not already be in the active set.
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index dcaaafe..11bd06e 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -2419,7 +2419,8 @@
tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
PyObject* inputs, PyObject* results,
const std::vector<PyObject*>& input_tangents,
- std::vector<PyObject*>* output_tangents) {
+ std::vector<PyObject*>* output_tangents,
+ bool use_batch) {
if (forward_gradient_function == nullptr) {
return tensorflow::errors::Internal(
"No forward gradient function registered.");
@@ -2430,9 +2431,11 @@
// Normalize the input sequence to a tuple so it works with function
// caching; otherwise it may be an opaque _InputList object.
tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
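+ // Convert the batching flag to a Python bool for the JVP callback.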
+ PyObject* to_batch = use_batch ? Py_True : Py_False;
tensorflow::Safe_PyObjectPtr callback_args(
- Py_BuildValue("OOOOO", op_name, attrs, input_tuple.get(), results,
- py_input_tangents.get()));
+ Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results,
+ py_input_tangents.get(), to_batch));
tensorflow::Safe_PyObjectPtr py_result(
PyObject_CallObject(forward_gradient_function, callback_args.get()));
if (py_result == nullptr || PyErr_Occurred()) {
@@ -2555,7 +2558,8 @@
} else {
tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
[forward_function](const std::vector<PyObject*>& input_tangents,
- std::vector<PyObject*>* output_tangents) {
+ std::vector<PyObject*>* output_tangents, bool use_batch) {
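+ // use_batch is ignored here; op-specific forward functions are not batched.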
return CallOpSpecificJVPFunction(forward_function, input_tangents,
output_tangents);
});
@@ -2797,7 +2801,7 @@
return PyList_New(0);
}
-PyObject* TFE_Py_ForwardAccumulatorNew() {
+PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) {
TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
TFE_Py_ForwardAccumulator* accumulator =
@@ -2808,7 +2812,7 @@
"ForwardAccumulator requires a PyVSpace to be registered."),
nullptr);
}
- accumulator->accumulator = new ForwardAccumulator(*py_vspace);
+ accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch);
return reinterpret_cast<PyObject*>(accumulator);
}
@@ -3166,9 +3170,9 @@
tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
[op_name, attrs, inputs, results](
const std::vector<PyObject*>& input_tangents,
- std::vector<PyObject*>* output_tangents) {
+ std::vector<PyObject*>* output_tangents, bool use_batch) {
return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
- output_tangents);
+ output_tangents, use_batch);
});
tensorflow::eager::ForwardFunction<PyObject>* forward_function;
if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index bf11faa..5572cfb 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -730,8 +730,8 @@
});
// TFE_Py_ForwardAccumulator logic.
- m.def("TFE_Py_ForwardAccumulatorNew", []() {
- return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
+ m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
+ return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
});
m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {