blob: 1c9c81827e0f251c8ae7bc47242334fb202835ac [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.tensor_forest.ops.tensor_forest."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from google.protobuf.json_format import ParseDict
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
class TensorForestTest(test_util.TensorFlowTestCase):
def testForestHParams(self):
hparams = tensor_forest.ForestHParams(
num_classes=2,
num_trees=100,
max_nodes=1000,
split_after_samples=25,
num_features=60).fill()
self.assertEquals(2, hparams.num_classes)
self.assertEquals(3, hparams.num_output_columns)
self.assertEquals(10, hparams.num_splits_to_consider)
# Default value of valid_leaf_threshold
self.assertEquals(1, hparams.valid_leaf_threshold)
self.assertEquals(0, hparams.base_random_seed)
def testForestHParamsBigTree(self):
hparams = tensor_forest.ForestHParams(
num_classes=2,
num_trees=100,
max_nodes=1000000,
split_after_samples=25,
num_features=1000).fill()
self.assertEquals(31, hparams.num_splits_to_consider)
def testForestHParamsStringParams(self):
hparams = tensor_forest.ForestHParams(
num_classes=2,
num_trees=100,
max_nodes=1000000,
split_after_samples="25",
num_splits_to_consider="1000000",
num_features=1000).fill()
self.assertEquals("1000000", hparams.num_splits_to_consider)
def testTrainingConstructionClassification(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
input_labels = [0, 1, 2, 3]
params = tensor_forest.ForestHParams(
num_classes=4,
num_features=2,
num_trees=10,
max_nodes=1000,
split_after_samples=25).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
graph = graph_builder.training_graph(input_data, input_labels)
self.assertTrue(isinstance(graph, ops.Operation))
def testTrainingConstructionRegression(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
input_labels = [0, 1, 2, 3]
params = tensor_forest.ForestHParams(
num_classes=4,
num_features=2,
num_trees=10,
max_nodes=1000,
split_after_samples=25,
regression=True).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
graph = graph_builder.training_graph(input_data, input_labels)
self.assertTrue(isinstance(graph, ops.Operation))
def testInferenceConstruction(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
params = tensor_forest.ForestHParams(
num_classes=4,
num_features=2,
num_trees=10,
max_nodes=1000,
split_after_samples=25).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
def testInfrenceFromRestoredModel(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
expected_prediction = [[0.0, 1.0], [0.0, 1.0],
[0.0, 1.0], [0.0, 1.0]]
hparams = tensor_forest.ForestHParams(
num_classes=2,
num_features=2,
num_trees=1,
max_nodes=1000,
split_after_samples=25).fill()
tree_weight = {'decisionTree':
{'nodes':
[{'binaryNode':
{'rightChildId': 2,
'leftChildId': 1,
'inequalityLeftChildTest':
{'featureId': {'id': '0'},
'threshold': {'floatValue': 0}}}},
{'leaf': {'vector':
{'value': [{'floatValue': 0.0},
{'floatValue': 1.0}]}},
'nodeId': 1},
{'leaf': {'vector':
{'value': [{'floatValue': 0.0},
{'floatValue': 1.0}]}},
'nodeId': 2}]}}
restored_tree_param = ParseDict(tree_weight,
_tree_proto.Model()).SerializeToString()
graph_builder = tensor_forest.RandomForestGraphs(hparams,
[restored_tree_param])
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
with self.test_session():
variables.global_variables_initializer().run()
resources.initialize_resources(resources.shared_resources()).run()
self.assertEquals(probs.eval().shape, (4, 2))
self.assertEquals(probs.eval().tolist(), expected_prediction)
def testTrainingConstructionClassificationSparse(self):
input_data = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],
values=[-1.0, 0.0, -1., 2., 1., -2.0],
dense_shape=[4, 10])
input_labels = [0, 1, 2, 3]
params = tensor_forest.ForestHParams(
num_classes=4,
num_features=10,
num_trees=10,
max_nodes=1000,
split_after_samples=25).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
graph = graph_builder.training_graph(input_data, input_labels)
self.assertTrue(isinstance(graph, ops.Operation))
def testInferenceConstructionSparse(self):
input_data = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 3],
[1, 0], [1, 7],
[2, 1],
[3, 9]],
values=[-1.0, 0.0,
-1., 2.,
1.,
-2.0],
dense_shape=[4, 10])
params = tensor_forest.ForestHParams(
num_classes=4,
num_features=10,
num_trees=10,
max_nodes=1000,
regression=True,
split_after_samples=25).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
if __name__ == "__main__":
googletest.main()