| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for tf.contrib.tensor_forest.ops.tensor_forest.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| from google.protobuf.json_format import ParseDict |
| from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto |
| from tensorflow.contrib.tensor_forest.python import tensor_forest |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import resources |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import googletest |
| |
| |
class TensorForestTest(test_util.TensorFlowTestCase):
  """Tests ForestHParams filling and RandomForestGraphs graph construction."""

  def _dense_input(self):
    """Returns a fresh 4-example, 2-feature dense input used by several tests."""
    return [[-1., 0.], [-1., 2.],  # node 1
            [1., 0.], [1., -2.]]  # node 2

  def _sparse_input(self):
    """Returns a fresh SparseTensor input with dense_shape [4, 10]."""
    return sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 3],
                 [1, 0], [1, 7],
                 [2, 1],
                 [3, 9]],
        values=[-1.0, 0.0,
                -1., 2.,
                1.,
                -2.0],
        dense_shape=[4, 10])

  def _assert_inference_outputs_are_tensors(self, probs, paths, var):
    """Asserts that each value returned by inference_graph is an ops.Tensor."""
    self.assertIsInstance(probs, ops.Tensor)
    self.assertIsInstance(paths, ops.Tensor)
    self.assertIsInstance(var, ops.Tensor)

  def testForestHParams(self):
    """Checks the values produced by fill() for a small forest."""
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_trees=100,
        max_nodes=1000,
        split_after_samples=25,
        num_features=60).fill()
    self.assertEqual(2, hparams.num_classes)
    self.assertEqual(3, hparams.num_output_columns)
    self.assertEqual(10, hparams.num_splits_to_consider)
    # Default value of valid_leaf_threshold
    self.assertEqual(1, hparams.valid_leaf_threshold)
    self.assertEqual(0, hparams.base_random_seed)

  def testForestHParamsBigTree(self):
    """Checks num_splits_to_consider after fill() with 1000 features."""
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_trees=100,
        max_nodes=1000000,
        split_after_samples=25,
        num_features=1000).fill()
    self.assertEqual(31, hparams.num_splits_to_consider)

  def testForestHParamsStringParams(self):
    """Checks that a string num_splits_to_consider is kept as given by fill()."""
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_trees=100,
        max_nodes=1000000,
        split_after_samples="25",
        num_splits_to_consider="1000000",
        num_features=1000).fill()
    self.assertEqual("1000000", hparams.num_splits_to_consider)

  def testTrainingConstructionClassification(self):
    """Builds a classification training graph from dense input."""
    input_data = self._dense_input()
    input_labels = [0, 1, 2, 3]

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=2,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.training_graph(input_data, input_labels)
    self.assertIsInstance(graph, ops.Operation)

  def testTrainingConstructionRegression(self):
    """Builds a training graph from dense input with regression=True."""
    input_data = self._dense_input()
    input_labels = [0, 1, 2, 3]

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=2,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25,
        regression=True).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.training_graph(input_data, input_labels)
    self.assertIsInstance(graph, ops.Operation)

  def testInferenceConstruction(self):
    """Builds an inference graph from dense input and checks output types."""
    input_data = self._dense_input()

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=2,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    probs, paths, var = graph_builder.inference_graph(input_data)
    self._assert_inference_outputs_are_tensors(probs, paths, var)

  def testInfrenceFromRestoredModel(self):
    """Runs inference with a single tree restored from a serialized proto."""
    input_data = self._dense_input()
    expected_prediction = [[0.0, 1.0], [0.0, 1.0],
                           [0.0, 1.0], [0.0, 1.0]]
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_features=2,
        num_trees=1,
        max_nodes=1000,
        split_after_samples=25).fill()
    # One binary node testing feature '0' against threshold 0, with two leaf
    # children. Both leaves carry the same vector [0.0, 1.0], which is why the
    # expected prediction is identical for every input row.
    tree_weight = {'decisionTree':
                       {'nodes':
                            [{'binaryNode':
                                  {'rightChildId': 2,
                                   'leftChildId': 1,
                                   'inequalityLeftChildTest':
                                       {'featureId': {'id': '0'},
                                        'threshold': {'floatValue': 0}}}},
                             {'leaf': {'vector':
                                           {'value': [{'floatValue': 0.0},
                                                      {'floatValue': 1.0}]}},
                              'nodeId': 1},
                             {'leaf': {'vector':
                                           {'value': [{'floatValue': 0.0},
                                                      {'floatValue': 1.0}]}},
                              'nodeId': 2}]}}
    restored_tree_param = ParseDict(tree_weight,
                                    _tree_proto.Model()).SerializeToString()
    graph_builder = tensor_forest.RandomForestGraphs(hparams,
                                                     [restored_tree_param])
    probs, paths, var = graph_builder.inference_graph(input_data)
    self._assert_inference_outputs_are_tensors(probs, paths, var)
    with self.test_session():
      variables.global_variables_initializer().run()
      resources.initialize_resources(resources.shared_resources()).run()
      # Evaluate once and reuse the result for both assertions.
      predictions = probs.eval()
      self.assertEqual((4, 2), predictions.shape)
      self.assertEqual(expected_prediction, predictions.tolist())

  def testTrainingConstructionClassificationSparse(self):
    """Builds a classification training graph from SparseTensor input."""
    input_data = self._sparse_input()
    input_labels = [0, 1, 2, 3]

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=10,
        num_trees=10,
        max_nodes=1000,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    graph = graph_builder.training_graph(input_data, input_labels)
    self.assertIsInstance(graph, ops.Operation)

  def testInferenceConstructionSparse(self):
    """Builds an inference graph from SparseTensor input with regression=True."""
    input_data = self._sparse_input()

    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=10,
        num_trees=10,
        max_nodes=1000,
        regression=True,
        split_after_samples=25).fill()

    graph_builder = tensor_forest.RandomForestGraphs(params)
    probs, paths, var = graph_builder.inference_graph(input_data)
    self._assert_inference_outputs_are_tensors(probs, paths, var)
| |
| |
if __name__ == "__main__":
  # Only hand control to the googletest runner when this file is executed as
  # a script, not when it is imported as a module.
  googletest.main()