Change cpp gradient function signature to accept a `Context` which contains the local context when calling the gradient function e.g. the default graph.
In python gradients, this state is managed via singletons(global context manager stacks) and is implicitly captured.

PiperOrigin-RevId: 322611049
Change-Id: I26fe086a687e4989a96f18baed1ccfe2ec7b7c1e
diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc
index 3a7a628..f5085fd 100644
--- a/tensorflow/c/eager/gradients.cc
+++ b/tensorflow/c/eager/gradients.cc
@@ -175,7 +175,8 @@
     gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
     std::vector<AbstractTensorHandle*>* result) const {
   if (backward_function == nullptr) return Status::OK();
-  return backward_function->Compute(output_gradients, result);
+  Context ctx = {ctx_};
+  return backward_function->Compute(&ctx, output_gradients, result);
 }
 
 // Looks up the ID of a Gradient.
diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h
index e09b6ff..267ee5b 100644
--- a/tensorflow/c/eager/gradients.h
+++ b/tensorflow/c/eager/gradients.h
@@ -31,7 +31,8 @@
 //
 // class AddGradientFunction : public GradientFunction {
 //  public:
-//   Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
+//   Status Compute(Context* ctx,
+//                  absl::Span<AbstractTensorHandle* const> grad_inputs,
 //                  std::vector<AbstractTensorHandle*>* grad_outputs) override {
 //     grad_outputs->resize(2);
 //     (*grad_outputs)[0] = grad_inputs[0];
@@ -50,11 +51,16 @@
 // Status RegisterGradients(GradientRegistry* registry) {
 //   return registry->Register("Add", AddRegisterer);
 // }
+struct Context {
+ public:
+  AbstractContext* ctx;
+};
 class GradientFunction {
  public:
   // TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in
   // `grad_inputs`.
-  virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
+  virtual Status Compute(Context* ctx,
+                         absl::Span<AbstractTensorHandle* const> grad_inputs,
                          std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
   virtual ~GradientFunction() {}
 };
diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc
index e27cbb2..47bd8cc 100644
--- a/tensorflow/c/experimental/gradients/math_grad.cc
+++ b/tensorflow/c/experimental/gradients/math_grad.cc
@@ -24,31 +24,30 @@
 
 class AddGradientFunction : public GradientFunction {
  public:
-  explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
-  Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
+  Status Compute(Context* ctx,
+                 absl::Span<AbstractTensorHandle* const> grad_inputs,
                  std::vector<AbstractTensorHandle*>* grad_outputs) override {
     grad_outputs->resize(2);
     std::vector<AbstractTensorHandle*> identity_outputs(1);
     // TODO(b/145674566): Handle name unification in tracing code.
     // TODO(b/161805092): Support broadcasting.
-    TF_RETURN_IF_ERROR(ops::Identity(
-        ctx_, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), "Identity0"));
+    TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
+                                     absl::MakeSpan(identity_outputs),
+                                     "Identity0"));
     (*grad_outputs)[0] = identity_outputs[0];
-    TF_RETURN_IF_ERROR(ops::Identity(
-        ctx_, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), "Identity1"));
+    TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
+                                     absl::MakeSpan(identity_outputs),
+                                     "Identity1"));
     (*grad_outputs)[1] = identity_outputs[0];
     return Status::OK();
   }
   ~AddGradientFunction() override {}
-
- private:
-  AbstractContext* ctx_;
 };
 
 }  // namespace
 
 GradientFunction* AddRegisterer(const ForwardOperation& op) {
-  return new AddGradientFunction(op.ctx);
+  return new AddGradientFunction;
 }
 }  // namespace gradients
 }  // namespace tensorflow