blob: d703a9cc630928bd57a7d3065245b79174be4996 [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rpc_ops.py."""
import threading
import time
import numpy as np
import portpicker
from tensorflow.python.distribute.experimental.rpc import rpc_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as eager_def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@test_util.with_eager_op_as_function
class RpcOpsTest(test.TestCase):
  """End-to-end tests for the RPC server/client ops exposed by rpc_ops."""

  def setUp(self):
    super(RpcOpsTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Split the first physical CPU into 2 logical CPUs so the multi-device
    # tests can place resources on "/device:CPU:1".
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])

  def _pick_address(self):
    """Returns a "localhost:<port>" address built from an unused port."""
    return "localhost:{}".format(portpicker.pick_unused_port())

  def test_generated_rpc_ops(self):
    """Drives the raw generated RPC ops directly (no Python wrappers)."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def remote_fn(a, b):
      return math_ops.multiply(a, b)

    concrete_remote_fn = remote_fn.get_concrete_function()

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    address = self._pick_address()
    server_resource = rpc_ops.gen_rpc_ops.rpc_server(server_address=address)

    rpc_ops.gen_rpc_ops.rpc_server_register(
        server_resource,
        f=concrete_remote_fn,
        captured_inputs=concrete_remote_fn.captured_inputs,
        output_specs=rpc_ops.get_output_specs_from_function(concrete_remote_fn),
        method_name="multiply")

    rpc_ops.gen_rpc_ops.rpc_server_start(server_resource)
    client_handle, _ = rpc_ops.gen_rpc_ops.rpc_client(
        server_address=address, timeout_in_ms=5000)

    future_resource, deleter = rpc_ops.gen_rpc_ops.rpc_call(
        client_handle, args=[a, b], method_name="multiply", timeout_in_ms=0)

    error_code, _ = rpc_ops.gen_rpc_ops.rpc_check_status(future_resource)
    self.assertAllEqual(error_code, 0)
    self.assertAllEqual(
        rpc_ops.gen_rpc_ops.rpc_get_value(future_resource, Tout=[dtypes.int32]),
        [6])

    # The deleter objects are created and dropped immediately; presumably
    # they release the server/client resources when collected.
    resource_variable_ops.EagerResourceDeleter(
        handle=server_resource, handle_device=server_resource.device)
    resource_variable_ops.EagerResourceDeleter(
        handle=client_handle, handle_device=client_handle.device)

    rpc_ops.gen_rpc_ops.delete_rpc_future_resource(future_resource, deleter)

  def test_exported_rpc_api_static_factory(self):
    """Exercises the Server.create / Client.create factory entry points."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_address()
    server_resource = rpc_ops.Server.create("grpc", address)
    server_resource.register("multiply", _remote_fn)

    server_resource.start()
    client = rpc_ops.Client.create("grpc", address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    mul_or = client.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    # Test empty client name
    client1 = rpc_ops.Client.create("grpc", address)
    mul_or = client1.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    # Test without output_spec
    mul_or = client1.multiply(a, b)
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)
    self.assertEqual(client1.multiply.__doc__,
                     "RPC Call for multiply method to server " + address)

  def test_rpc_ops_wrapper(self):
    """Registers both a tf.function and a ConcreteFunction on GrpcServer."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_address()
    server_resource = rpc_ops.GrpcServer(address)

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def add_fn(a, b):
      return math_ops.add(a, b)

    # Register TF function
    server_resource.register("multiply", _remote_fn)

    # Register concrete Function
    server_resource.register("add", add_fn.get_concrete_function())

    server_resource.start()
    client = rpc_ops.GrpcClient(address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    mul_or = client.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    add_or = client.call(
        args=[a, b],
        method_name="add",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(add_or.is_ok(), True)
    self.assertAllEqual(add_or.get_value(), 5)

    # Test empty client name
    client1 = rpc_ops.GrpcClient(address, list_registered_methods=True)
    mul_or = client1.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    # Test without output_spec
    mul_or = client1.multiply(a, b)
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)
    self.assertEqual(client1.multiply.__doc__,
                     "RPC Call for multiply method to server " + address)

  def test_output_specs(self):
    """Checks dict, bool, nested and empty return structures round-trip."""

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_dict(val):
      return {"key": val}

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def is_positive(a):
      if a > 0:
        return True
      return False

    @eager_def_function.function(input_signature=[])
    def do_nothing():
      return []

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_nested_structure(v):
      return {"test": (v, [v, v]), "test1": (v,)}

    address = self._pick_address()
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("test_dict", test_dict)
    server_resource.register("is_positive", is_positive)
    server_resource.register("test_nested_structure", test_nested_structure)
    server_resource.register("do_nothing", do_nothing)

    server_resource.start()

    client = rpc_ops.GrpcClient(
        address=address, name="test_client", list_registered_methods=True)

    a = variables.Variable(2, dtype=dtypes.int32)

    result_or = client.test_dict(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {"key": 2})

    # Check the RPC status and the returned value; asserting the truthiness
    # of the result object itself would pass regardless of the outcome.
    result_or = client.is_positive(a)
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), True)

    result_or = client.test_nested_structure(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {
        "test": (2, [2, 2]),
        "test1": (2,)
    })

    result_or = client.do_nothing()
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [])

  def test_input_specs(self):
    """Checks dict-structured inputs and rejection of mismatched inputs."""

    @eager_def_function.function(input_signature=[{
        "a": tensor_spec.TensorSpec([], dtypes.int32),
        "b": tensor_spec.TensorSpec([], dtypes.int32)
    }])
    def test_input_dict(value):
      return math_ops.add(value["a"], value["b"])

    address = self._pick_address()
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("test_input_dict", test_input_dict)

    server_resource.start()

    client = rpc_ops.GrpcClient(
        address=address, name="test_client", list_registered_methods=True)
    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)
    result_or = client.test_input_dict({"a": a, "b": b})
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), 5)

    with self.assertRaises(TypeError):
      client.test_input_dict([a, b])

  def test_call_register_ordering(self):
    """Checks client creation/calls relative to server registration/start."""
    address = self._pick_address()

    # Create client succeeds before server start and registration
    client = rpc_ops.GrpcClient(address)

    # Create client with list_registered_methods fails before server is started.
    with self.assertRaises(errors.DeadlineExceededError):
      rpc_ops.GrpcClient(
          address,
          name="client1",
          list_registered_methods=True,
          timeout_in_ms=1)

    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server = rpc_ops.GrpcServer(address)

    def start_server():
      # Delay server start to test whether client creation also waits
      # till server is up.
      time.sleep(1)
      server.register("assign_add", assign_add)
      server.start()

    t = threading.Thread(target=start_server)
    t.start()

    # Create same "client1" again should succeed.
    client1_with_listed_methods = rpc_ops.GrpcClient(
        address, name="client1", list_registered_methods=True)
    # The client above only comes up once the server is serving; join the
    # starter thread so server.start() has definitely returned before the
    # post-start registration check below.
    t.join()

    result_or = client1_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)

    # Create client with registered methods
    client2_with_listed_methods = rpc_ops.GrpcClient(
        address=address, name="client2", list_registered_methods=True)

    result_or = client2_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    self.assertAllEqual(v, 6)

    # Register new method after server started.
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "All methods must be registered before starting the server"):
      server.register("read_var", read_var)

  def test_client_timeout(self):
    """Checks client-level default deadlines and per-call timeouts."""
    address = self._pick_address()

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def add(a, b):
      return math_ops.add(a, b)

    server = rpc_ops.GrpcServer(address)

    def start_server():
      # Delay server start to simulate deadline exceeded for 1st RPC call
      # response. Client waits till server is started, thus it can trigger
      # deadline exceeded.
      time.sleep(1)
      server.register("add", add)
      server.start()

    t = threading.Thread(target=start_server)
    t.start()

    # Create client with list_registered_methods fails before server is started.
    with self.assertRaises(errors.DeadlineExceededError):
      rpc_ops.GrpcClient(
          address,
          name="client1",
          list_registered_methods=True,
          timeout_in_ms=1)

    # Create same client again should succeed with
    # list_registered_methods=False. Default timeout for client is 1 ms.
    client = rpc_ops.GrpcClient(
        address, name="client1", list_registered_methods=False, timeout_in_ms=1)

    # Make explicit RPC call, the default timeout of 1 ms should lead to
    # deadline exceeded error.
    result_or = client.call(
        "add", [constant_op.constant(20),
                constant_op.constant(30)])
    self.assertAllEqual(result_or.is_ok(), False)
    error_code, _ = result_or.get_error()
    self.assertAllEqual(error_code, errors.DEADLINE_EXCEEDED)

    # Specifying reasonable timeout for call should succeed.
    result_or = client.call(
        "add", [constant_op.constant(20),
                constant_op.constant(30)],
        timeout_in_ms=5000)
    self.assertAllEqual(result_or.is_ok(), True)
    error_code, _ = result_or.get_error()
    self.assertAllEqual(error_code, errors.OK)

    # The successful call means the server is up. Wait for the starter thread
    # to finish before `server` is deleted and rebound below, because the
    # thread's closure reads that same variable.
    t.join()

    # Test timeouts for convenience methods
    # Client with no default timeout.
    client = rpc_ops.GrpcClient(
        address, name="client2", list_registered_methods=True)

    # Restart server again with delay to simulate deadline exceeded.
    del server
    server = rpc_ops.GrpcServer(address)
    t = threading.Thread(target=start_server)
    t.start()

    # Call fails with 1 ms timeout.
    result_or = client.add(
        constant_op.constant(20), constant_op.constant(30), timeout_in_ms=1)
    self.assertAllEqual(result_or.is_ok(), False)
    error_code, _ = result_or.get_error()
    self.assertAllEqual(error_code, errors.DEADLINE_EXCEEDED)

    # Succeeds with reasonable timeout.
    result_or = client.add(
        constant_op.constant(20), constant_op.constant(30), timeout_in_ms=5000)
    self.assertAllEqual(result_or.is_ok(), True)
    t.join()

  def test_async_call_op_wrapper(self):
    """Issues several calls before waiting on any of them."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = self._pick_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)
    server.start()

    client = rpc_ops.GrpcClient(address)

    futures = []
    for _ in range(10):
      futures.append(
          client.call("assign_add",
                      [variables.Variable(2, dtype=dtypes.int64)]))

    # Wait for every in-flight call AND verify it actually succeeded; a
    # silently failed call would otherwise only surface as a wrong sum.
    for f in futures:
      self.assertAllEqual(f.is_ok(), True)

    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])

    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [20])

  def test_rpc_call_op_in_tf_function(self):
    """Makes an RPC call from inside a tf.function."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_address()
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("remote_fn", _remote_fn)

    server_resource.start()
    client = rpc_ops.GrpcClient(address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    @eager_def_function.function
    def call_fn():
      result_or = client.call(
          args=[a, b],
          method_name="remote_fn",
          output_specs=[tensor_spec.TensorSpec([], dtypes.int32)])

      self.assertAllEqual(True, result_or.is_ok())
      result = result_or.get_value()
      self.assertEqual(len(result), 1)  # Call returns a list(tensors)
      # TODO(ishark): Shape for output tensor is unknown currently.
      # Add attribute for capturing TensorSpec for output and enable
      # check below:
      # self.assertIsNotNone(result[0].shape.rank)
      return result

    self.assertAllEqual(call_fn(), [6])

  def test_resource_deletion(self):
    """Checks that future, client and server resources are released."""
    address = self._pick_address()
    server = rpc_ops.GrpcServer(address)
    server_handle = server._server_handle

    # Test Future resource deletion
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server.register("read_var", read_var)

    server.start()
    client = rpc_ops.GrpcClient(address)

    client_handle = client._client_handle

    # Check future resource deletion without calling get_value.
    def _create_and_delete_rpc_future():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      return handle._status_or

    @eager_def_function.function
    def _create_and_delete_rpc_future_fn():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      return handle._status_or

    for _ in range(2):
      handle = _create_and_delete_rpc_future()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    for _ in range(2):
      handle = _create_and_delete_rpc_future_fn()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    # Check future resource deletion with calling get_value.
    def _create_and_delete_with_future():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      status_or_handle = handle._status_or
      handle.get_value()
      return status_or_handle

    # Check future resource deletion with calling get_value with tf.function.
    @eager_def_function.function
    def _create_and_delete_with_future_fn():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      status_or_handle = handle._status_or
      handle.get_value()
      return status_or_handle

    for _ in range(2):
      resource_handle = _create_and_delete_with_future()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            resource_handle, ignore_lookup_error=False)

    for _ in range(2):
      resource_handle = _create_and_delete_with_future_fn()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            resource_handle, ignore_lookup_error=False)

    # Test client resource gets deleted.
    del client
    with self.assertRaises(errors.NotFoundError):
      resource_variable_ops.destroy_resource_op(
          client_handle, ignore_lookup_error=False)

    # Test server resource gets deleted.
    del server
    with self.assertRaises(errors.NotFoundError):
      resource_variable_ops.destroy_resource_op(
          server_handle, ignore_lookup_error=False)

  def test_rpc_error(self):
    """Checks both call styles succeed and a missing-argument call fails."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = self._pick_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)
    server.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)

    # confirm it works as expected when arguments are passed.
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [2])
    result_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(True, result_or.is_ok())

    result_or = client.read_var()
    self.assertAllEqual(True, result_or.is_ok())
    self.assertAllEqual(result_or.get_value(), 4)

    # Fails with invalid argument error when no arguments are passed.
    result_or = client.call("assign_add")
    self.assertAllEqual(result_or.is_ok(), False)
    error_code, _ = result_or.get_error()
    self.assertAllEqual(error_code, errors.INVALID_ARGUMENT)

  def test_captured_inputs(self):
    """Checks that server functions capturing a variable share its state."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = self._pick_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)

    server.start()

    client = rpc_ops.GrpcClient(address)
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])

    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [4])

  def test_register_method_twice(self):
    """Checks that registering the same method name twice is rejected."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign(a):
      v.assign(a)

    address = self._pick_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign", assign_add)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "assign is already registered."):
      # Reusing the same method name.
      server.register("assign", assign)

  def test_tf_function_register_without_input_signature(self):
    """Checks the input_signature requirement for functions with args."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function
    def assign(a):
      v.assign(a)

    address = self._pick_address()
    server = rpc_ops.GrpcServer(address)
    with self.assertRaisesRegex(
        ValueError, "Input signature not specified for the function."):
      server.register("assign", assign)

    # Register without input signature should work for functions without input
    # args.
    @eager_def_function.function
    def read_var():
      return v.value()

    server.register("read_var", read_var)

  def test_multi_device_resource(self):
    """Server on CPU:0 populates a queue resource placed on CPU:1."""
    elements = np.random.randint(100, size=[200])

    with ops.device("/device:CPU:1"):
      queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

    @eager_def_function.function()
    def populate_queue():
      queue.enqueue_many(elements)
      queue.close()

    with ops.device("/device:CPU:0"):
      address = self._pick_address()
      server = rpc_ops.GrpcServer(address)
      server.register("populate_queue", populate_queue)
      server.start()

      client = rpc_ops.GrpcClient(address, list_registered_methods=True)
      client.populate_queue()

    for e in elements:
      self.assertAllEqual(e, queue.dequeue())

  def test_queue_resource(self):
    """Server populates a queue resource on the default device."""
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

    @eager_def_function.function()
    def populate_queue():
      queue.enqueue_many(elements)
      queue.close()

    address = self._pick_address()
    server = rpc_ops.GrpcServer(address)
    server.register("populate_queue", populate_queue)
    server.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)
    client.populate_queue()

    for e in elements:
      self.assertAllEqual(e, queue.dequeue())

  def test_multi_device_resource_cpu(self):
    """Server on CPU:0 updates a variable placed on CPU:1."""
    with ops.device("/device:CPU:1"):
      v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    with ops.device("/device:CPU:0"):
      address = self._pick_address()
      server = rpc_ops.GrpcServer(address)
      server.register("assign_add", assign_add)
      server.start()

      client = rpc_ops.GrpcClient(address, list_registered_methods=True)
      result_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
      self.assertAllEqual(result_or.is_ok(), True)

    self.assertAllEqual(v, 2)
def _main():
  """Turns on eager execution, then hands control to the TF test runner."""
  ops.enable_eager_execution()
  test.main()


if __name__ == "__main__":
  _main()