add `TestSubGrad` to `math_grad_test`
diff --git a/tensorflow/c/experimental/gradients/math_grad_test.cc b/tensorflow/c/experimental/gradients/math_grad_test.cc
index f743bb3..983cd72 100644
--- a/tensorflow/c/experimental/gradients/math_grad_test.cc
+++ b/tensorflow/c/experimental/gradients/math_grad_test.cc
@@ -143,6 +143,35 @@
return Status::OK();
}
+Status SubModel(AbstractContext* ctx,
+ absl::Span<AbstractTensorHandle* const> inputs,
+ absl::Span<AbstractTensorHandle*> outputs) {
+ return ops::Sub(ctx, inputs, outputs, "Sub");
+}
+
+Status SubGradModel(AbstractContext* ctx,
+ absl::Span<AbstractTensorHandle* const> inputs,
+ absl::Span<AbstractTensorHandle*> outputs) {
+ GradientRegistry registry;
+ TF_RETURN_IF_ERROR(registry.Register("Sub", SubRegisterer));
+
+ Tape tape(/*persistent=*/false);
+ tape.Watch(inputs[0]);
+ tape.Watch(inputs[1]);
+ std::vector<AbstractTensorHandle*> temp_outputs(1);
+ AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
+ TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
+ absl::MakeSpan(temp_outputs), "SubGrad"));
+
+ TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
+ /*sources=*/inputs,
+ /*output_gradients=*/{}, outputs));
+ for (auto temp_output : temp_outputs) {
+ temp_output->Unref();
+ }
+ return Status::OK();
+}
+
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
@@ -228,6 +257,27 @@
NegModel, NegGradModel, ctx_.get(), {x.get()}, UseFunction()));
}
+TEST_P(CppGradients, TestSubGrad) {
+ AbstractTensorHandlePtr x;
+ {
+ AbstractTensorHandle* x_raw = nullptr;
+ Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+ ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+ x.reset(x_raw);
+ }
+
+ AbstractTensorHandlePtr y;
+ {
+ AbstractTensorHandle* y_raw = nullptr;
+ Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
+ ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+ y.reset(y_raw);
+ }
+
+ ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
+ SubModel, SubGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
+}
+
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,