Make keras sequential work with tfnumpy.ndarray
PiperOrigin-RevId: 323718791
Change-Id: I1000d893e92144136abcc93e5c2718f9835065a4
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index ac6c1a9..a63d4994 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -71,6 +71,7 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
@@ -929,7 +930,8 @@
call_context = base_layer_utils.call_context()
# Accept NumPy and scalar inputs by converting to Tensors.
- if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
+ if any(isinstance(x, (
+ np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
inputs = nest.map_structure(_convert_numpy_or_python_types, inputs)
input_list = nest.flatten(inputs)
@@ -997,12 +999,13 @@
call_context = base_layer_utils.call_context()
# Accept NumPy and scalar inputs by converting to Tensors.
- if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
+ if any(isinstance(x, (
+ np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
def _convert_non_tensor(x):
# Don't call `ops.convert_to_tensor_v2` on all `inputs` because
# `SparseTensors` can't be converted to `Tensor`.
- if isinstance(x, (np.ndarray, float, int)):
+ if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
return ops.convert_to_tensor_v2(x)
return x
@@ -3242,7 +3245,7 @@
def _convert_numpy_or_python_types(x):
- if isinstance(x, (np.ndarray, float, int)):
+ if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
return ops.convert_to_tensor_v2(x)
return x
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 979c1e4..5957576 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -520,6 +520,8 @@
def _get_shape_tuple(t):
if hasattr(t, 'shape'):
shape = t.shape
+ if isinstance(shape, tuple):
+ return shape
if shape.rank is not None:
return tuple(shape.as_list())
return None
diff --git a/tensorflow/python/ops/numpy_ops/np_interop_test.py b/tensorflow/python/ops/numpy_ops/np_interop_test.py
index 9074f37..c0cd8e8 100644
--- a/tensorflow/python/ops/numpy_ops/np_interop_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_interop_test.py
@@ -273,6 +273,33 @@
self.assertIsInstance(result, np.ndarray)
self.assertAllClose(result, onp.square(values))
+ def testKerasInteropSequential(self):
+ class ProjectionLayer(tf.keras.layers.Layer):
+ """Linear projection layer using TF NumPy."""
+
+ def __init__(self, units):
+ super(ProjectionLayer, self).__init__()
+ self._units = units
+
+ def build(self, input_shape):
+ stddev = np.sqrt(self._units).astype(np.float32)
+ initial_value = np.random.randn(input_shape[1], self._units).astype(
+ np.float32) / stddev
+ # Note that TF NumPy can interoperate with tf.Variable.
+ self.w = tf.Variable(initial_value, trainable=True)
+
+ def call(self, inputs):
+ return np.matmul(inputs, self.w)
+
+ model = tf.keras.Sequential(
+ [tf.keras.layers.Dense(100), ProjectionLayer(2)])
+ output = model.call(np.random.randn(10, 100))
+
+ self.assertIsInstance(output, np.ndarray)
+
+ dense_layer = tf.keras.layers.Dense(100)
+ output = dense_layer(np.random.randn(10, 100))
+
def testPForInterop(self):
def outer_product(a):
return np.tensordot(a, a, 0)