# blob: 01ee8703a93836d607ee9b765c51c79fe3bb974f [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the currently experimental in-graph batch ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
def delayed_plus1(x):
  """Blocks the calling thread for 100ms, then returns `x + 1`.

  Args:
    x: Any value supporting `x + 1`.

  Returns:
    The result of `x + 1`, computed after the delay has elapsed.
  """
  delay_secs = 0.1
  time.sleep(delay_secs)
  incremented = x + 1
  return incremented
class BatchOpsTest(test.TestCase):
  """Tests for batch_ops.{un,}batch."""

  def _run_in_thread_and_main(self, sess, fetches, thread_feed, main_feed):
    """Runs `fetches` concurrently from a worker thread and the calling thread.

    Two `sess.run` calls are issued so that both inputs are in flight at the
    same time: one from a freshly started worker thread and one from the
    calling thread. The worker is joined before returning.

    Args:
      sess: Session used for both `sess.run` calls.
      fetches: Fetches passed unchanged to both `sess.run` calls.
      thread_feed: Feed dict for the call issued from the worker thread.
      main_feed: Feed dict for the call issued from the calling thread.

    Returns:
      A `(thread_results, main_results)` tuple. `thread_results` is a list
      extended with the values returned by the worker's `sess.run` (it stays
      empty if that call raised); `main_results` is the value returned by the
      calling thread's `sess.run`.
    """
    thread_results = []

    def worker():
      thread_results.extend(sess.run(fetches, feed_dict=thread_feed))

    worker_thread = threading.Thread(target=worker)
    worker_thread.start()
    main_results = sess.run(fetches, feed_dict=main_feed)
    worker_thread.join()
    return thread_results, main_results

  def testBasicBatch(self):
    """Tests that a single batched tensor executes together and only once."""
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, _ = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=0,
          batching_queue="")
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [batched, index], {inp: [1]}, {inp: [2]})

      # At this point either the thread or the main did the batch and the other
      # should have empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0][0]
        index_t = thread_results[1]
        empty_b = main_results[0][0]
        empty_m = main_results[1]
      else:
        batch_t = main_results[0][0]
        index_t = main_results[1]
        empty_b = thread_results[0][0]
        empty_m = thread_results[1]

      # Check that both the inputs made it out exactly once.
      self.assertAllEqual(sorted(batch_t), (1, 2))
      # Check that we get 2 rows in the index tensor.
      self.assertEqual(len(index_t), 2)
      # Check that the other ones are empty.
      self.assertEqual(len(empty_b), 0)
      self.assertEqual(len(empty_m), 0)

  def testBatchWithPadding(self):
    """Test that batching with padding up to an allowed batch size works."""
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
      batched, index, _ = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[5, 10],
          grad_timeout_micros=0, batching_queue="")
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [batched, index], {inp: [1, 3]}, {inp: [2, 4]})

      # At this point either the thread or the main did the batch and the other
      # should have empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0][0]
      else:
        batch_t = main_results[0][0]

      # Check that the batch tensor incorporates the padding.
      self.assertEqual(len(batch_t), 5)

  def testMultipleBatch(self):
    """Tests that multiple batched tensors execute together."""
    with self.cached_session() as sess:
      inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, _, _ = batch_ops.batch(
          [inp0, inp1],
          num_batch_threads=1,
          max_batch_size=2,
          batch_timeout_micros=36000000,
          grad_timeout_micros=0,
          batching_queue="")
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [batched], {inp0: [1], inp1: [2]}, {inp0: [2], inp1: [3]})

      # At this point either the thread or the main did the batch and the other
      # should have empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0]
        empty_t = main_results[0]
      else:
        batch_t = main_results[0]
        empty_t = thread_results[0]

      # Assert that the tensors were batched together.
      self.assertAllEqual(sorted(batch_t[0]), [1, 2])
      self.assertAllEqual(sorted(batch_t[1]), [2, 3])
      self.assertAllEqual(empty_t[0], [])
      self.assertAllEqual(empty_t[1], [])

  def testIllegalBatchDifferentDim0Sizes(self):
    """Tests illegally feeding tensors with different dim0 sizes."""
    with self.cached_session() as sess:
      inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
      batched, index, _ = batch_ops.batch(
          [inp0, inp1], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
      # Match on the stringified exception rather than on `.message.find(...)`
      # so the check neither depends on a `.message` attribute nor fails when
      # the expected text sits at the very start of the message.
      with self.assertRaisesRegexp(Exception,
                                   "must have equal 0th-dimension size"):
        _ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})

  def testBasicUnbatch(self):
    """Tests that batch and unbatch work together."""
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[3, 10],
          grad_timeout_micros=0, batching_queue="")
      computation = batched[0] + 1
      result = batch_ops.unbatch(computation, index, id_t,
                                 timeout_micros=1000000, shared_name="unbatch")
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [result], {inp: [1]}, {inp: [2]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBasicUnbatchV1Decorated(self):
    """Tests that the batch_function_v1 decorator works."""
    with self.cached_session() as sess:

      @batch_ops.batch_function_v1(1, 10, 100000)
      def computation(in_t):
        return in_t + 1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = computation(inp)
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [result], {inp: [1]}, {inp: [2]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBasicUnbatchDecorated(self):
    """Tests that the batch_function decorator works."""
    with self.cached_session() as sess:
      # TODO(apassos): Removing this line causes test flakiness! Ideally should
      # be investigated.
      default_inp = array_ops.placeholder_with_default(2, shape=[])  # pylint: disable=unused-variable

      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        return in_t + 1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = computation(inp)
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [result], {inp: [1]}, {inp: [2]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchDecoratedWithCapturedInput(self):
    """Tests that the batch_function decorator works with captured inputs."""
    with self.cached_session() as sess:
      captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
      captured_inp1 = array_ops.placeholder_with_default(1, shape=[])

      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        return in_t + captured_inp0 - captured_inp1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = computation(inp)
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [result], {inp: [1]}, {inp: [2]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchFunctionOp(self):
    """Tests that the batch_function op works."""
    with self.cached_session() as sess:

      @function.Defun(dtypes.int32)
      def computation(in_t):
        return in_t + 1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = gen_batch_ops.batch_function(
          [inp],
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,
          Tout=[dtypes.int32],
          f=computation,
          captured_tensors=computation.captured_inputs)
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [result], {inp: [1]}, {inp: [2]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchFunctionOpWithCapturedInput(self):
    """Tests that batch_function op works with captured input."""
    with self.cached_session() as sess:
      captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
      captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

      @function.Defun(dtypes.int32)
      def computation(inp):
        return inp + captured_inp0 - captured_inp1

      result = gen_batch_ops.batch_function(
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[3, 10],
          batching_queue="",
          f=computation,
          in_tensors=[inp],
          captured_tensors=computation.captured_inputs,
          Tout=[o.type for o in computation.definition.signature.output_arg])
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [result], {inp: [1]}, {inp: [2]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchFunctionOpWithInputError(self):
    """Tests that batch_function op raises when given too few inputs."""
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

      @function.Defun(dtypes.int32, dtypes.int32)
      def computation(in0, in1):
        return in0 + in1

      result = gen_batch_ops.batch_function(
          [inp],  # computation actually expects 2 inputs.
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          batching_queue="",
          f=computation,
          captured_tensors=computation.captured_inputs,
          Tout=[o.type for o in computation.definition.signature.output_arg])

      with self.assertRaisesRegexp(InvalidArgumentError,
                                   ".*2 arguments.*but 1.*"):
        sess.run([result], feed_dict={inp: [2]})

  def testBasicUnbatchDecoratedWithReshape(self):
    """Tests the batch_function decorator on a function that reshapes."""
    with self.cached_session() as sess:

      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        return array_ops.reshape(in_t, [-1]) + 1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
      result = computation(inp)
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [result], {inp: [[1]]}, {inp: [[2]]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testUnbatchTimeout(self):
    """Tests that the unbatch timeout works."""
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=0,
          batching_queue="")
      computation = batched[0] + 1
      timeout_micros = 10
      result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
                                 shared_name="shared_unbatch")

      # Set up a parallel pipeline that delays the computation, but uses the
      # same unbatch resource object as the non-delayed pipeline.
      computation_delayed = script_ops.py_func(delayed_plus1,
                                               [batched[0]],
                                               dtypes.int32)
      result_delayed = batch_ops.unbatch(computation_delayed,
                                         index,
                                         id_t,
                                         timeout_micros,
                                         shared_name="shared_unbatch")

      # This test is kept inline (not routed through the shared helper)
      # because the two calls use different fetches and must be staggered.
      thread_results = []

      def worker():
        # A first call using the non-delayed pipeline. The batcher will send an
        # empty tensor along the non-delayed pipeline.
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      time.sleep(0.1)  # Ensure the thread's call starts first.

      # A second call using the delayed pipeline. The batcher will send the
      # batched tensor along the delayed pipeline, thus delaying the arrival of
      # the batched tensor at the unbatch op, relative to the empty tensor.
      #
      # TODO(olston, apassos): Avoid relying on the order in which the batch op
      # emits the empty tensor versus the batched one.
      _ = sess.run([result_delayed], feed_dict={inp: [2]})
      worker_thread.join()

      # The thread's call should hit the timeout, and thus get 0 results.
      self.assertEqual(len(thread_results), 0)

  def testUnbatchGrad(self):
    """Tests that batch and unbatch are differentiable."""
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=1000000,
          batching_queue="")
      computation = batched[0] * batched[0]
      result = batch_ops.unbatch(computation, index, id_t,
                                 timeout_micros=1000000, shared_name="unbatch")
      grad = gradients_impl.gradients(result, inp)
      thread_results, main_results = self._run_in_thread_and_main(
          sess, [grad], {inp: [1]}, {inp: [2]})
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [4])
if __name__ == "__main__":
  # Only launch the test runner when executed as a script, not on import.
  test.main()