# blob: 08908104f4d1139168daf0ea5cbe34b13990e065 (code-viewer header residue; not part of the module)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Rate."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.rate import rate
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
class RateTest(test.TestCase):
  """Unit tests for `rate.Rate`.

  NOTE(review): the expected values below are consistent with `Rate`
  returning (change in numerator) / (change in denominator) relative to the
  previous call, and 0 when the denominator did not change -- confirm against
  the `rate.Rate` implementation, which is not visible from this file.
  """

  @test_util.run_in_graph_and_eager_modes()
  def testBuildRate(self):
    """Calling a Rate that was already built keeps the same `numer` object."""
    m = rate.Rate()
    m.build(
        constant_op.constant([1], dtype=dtypes.float32),
        constant_op.constant([2], dtype=dtypes.float32))
    old_numer = m.numer
    m(
        constant_op.constant([2], dtype=dtypes.float32),
        constant_op.constant([2], dtype=dtypes.float32))
    # Identity check: the call must not have replaced the object created by
    # the explicit build(). assertIs reports both operands on failure.
    self.assertIs(old_numer, m.numer)

  @test_util.run_in_graph_and_eager_modes()
  def testBasic(self):
    """Successive calls on one Rate instance yield the expected values."""
    with self.test_session():
      r_ = rate.Rate()
      a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
      # Initializers run after the first call, once that call has had the
      # chance to create the instance's variables.
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual([[1]], self.evaluate(a))
      b = r_(constant_op.constant([2]), denominator=constant_op.constant([2]))
      self.assertEqual([[1]], self.evaluate(b))
      c = r_(constant_op.constant([4]), denominator=constant_op.constant([3]))
      self.assertEqual([[2]], self.evaluate(c))
      # Same denominator (3) as the previous call.
      d = r_(constant_op.constant([16]), denominator=constant_op.constant([3]))
      self.assertEqual([[0]], self.evaluate(d))  # divide by 0

  def testNamesWithSpaces(self):
    """A name containing a space is kept as-is on `name`.

    The `prev_values` variable name is expected to use an underscore instead
    of the space.
    """
    m1 = rate.Rate(name="has space")
    m1(array_ops.ones([1]), array_ops.ones([1]))
    self.assertEqual(m1.name, "has space")
    self.assertEqual(m1.prev_values.name, "has_space_1/prev_values:0")

  @test_util.run_in_graph_and_eager_modes()
  def testWhileLoop(self):
    """A Rate instance can be called from inside a `while_loop` body."""
    with self.test_session():
      r_ = rate.Rate()

      def body(value, denom, i, ret_rate):
        """Loop body: call the rate, then add 2 to value and 1 to denom."""
        i += 1
        ret_rate = r_(value, denom)
        # Build the additions under a control dependency on the rate result.
        with ops.control_dependencies([ret_rate]):
          value = math_ops.add(value, 2)
          denom = math_ops.add(denom, 1)
        return [value, denom, i, ret_rate]

      def condition(v, d, i, r):
        """Loop condition: continue while the counter is below 100."""
        del v, d, r  # unused vars by condition
        return math_ops.less(i, 100)

      i = constant_op.constant(0)
      value = constant_op.constant([1], dtype=dtypes.float64)
      denom = constant_op.constant([1], dtype=dtypes.float64)
      # First call happens outside the loop, before the initializers run.
      ret_rate = r_(value, denom)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(variables.local_variables_initializer())
      loop = control_flow_ops.while_loop(condition, body,
                                         [value, denom, i, ret_rate])
      # loop[3] is the final ret_rate; value grows by 2 for each +1 of denom.
      self.assertEqual([[2]], self.evaluate(loop[3]))
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()