Clean up OpKernel 2.
Move DeviceNumaNode to DeviceBase.
PiperOrigin-RevId: 291064902
Change-Id: I1c19e92e6e44ed43fbd0b313391781e9668f2367
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index b6696df..b8890dd 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -22,6 +22,7 @@
#include "absl/base/macros.h"
#include "absl/strings/string_view.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -234,6 +235,7 @@
// Unimplemented by default
virtual const DeviceAttributes& attributes() const;
+ virtual int NumaNode() const { return attributes().locality().numa_node(); }
virtual const string& name() const;
// Materializes the given TensorProto into 'tensor' stored in Device
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 753188a..76cee64 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -137,12 +137,6 @@
const string& OpKernel::requested_device() const { return def_->device(); }
const string& OpKernel::requested_input(int i) const { return def_->input(i); }
-// This static function exists only because device_attributes.pb.h is
-// already included here, and it can't be introduced elsewhere.
-/*static*/ int OpKernel::DeviceNumaNode(const DeviceBase* device) {
- return device->attributes().locality().numa_node();
-}
-
Status OpKernel::InputRange(StringPiece input_name, int* start,
int* stop) const {
const auto result = input_name_map_.find(input_name);
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e92da08..1196f79 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -220,8 +220,6 @@
// TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
Status MakeShape(const Tensor& shape, TensorShape* out) const;
- static int DeviceNumaNode(const DeviceBase* device);
-
// Returns `true` if and only if this kernel uses deferred execution.
bool is_deferred() const { return is_deferred_; }
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 8ae61b9..5c62e51 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -269,7 +269,7 @@
const Variant& v = input.scalar<Variant>()();
// DT_VARIANT tensors must be allocated on CPU since they wrap C++
// objects which can not be efficiently represented in GPU memory.
- int numa_node = DeviceNumaNode(ctx->device());
+ int numa_node = ctx->device()->NumaNode();
Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index f2d0a95..c8ac410 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -331,7 +331,7 @@
const Variant& v = inp.scalar<Variant>()();
Variant v_out;
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
- int numa_node = DeviceNumaNode(ctx->device());
+ int numa_node = ctx->device()->NumaNode();
Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
out.scalar<Variant>()() = std::move(v_out);
ctx->set_output(0, std::move(out));