| # Owner(s): ["module: unknown"] | 
 |  | 
 | import hypothesis.strategies as st | 
 | from hypothesis import given | 
 | import numpy as np | 
 | import torch | 
 | from torch.testing._internal.common_utils import TestCase | 
 | import torch.testing._internal.hypothesis_utils as hu | 
# Module-level guard for the hypothesis-driven tests below.
# NOTE(review): presumably asserts that hypothesis' per-example deadline is
# disabled for this run — confirm against hypothesis_utils.
hu.assert_deadline_disabled()
 |  | 
 |  | 
 | class PruningOpTest(TestCase): | 
 |  | 
 |     # Generate rowwise mask vector based on indicator and threshold value. | 
 |     # indicator is a vector that contains one value per weight row and it | 
 |     # represents the importance of a row. | 
 |     # We mask a row if its indicator value is less than the threshold. | 
 |     def _generate_rowwise_mask(self, embedding_rows): | 
 |         indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32)) | 
 |         threshold = np.random.random_sample() | 
 |         mask = torch.BoolTensor([True if val >= threshold else False for val in indicator]) | 
 |         return mask | 
 |  | 
 |     def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype): | 
 |         embedding_weights = None | 
 |         if weights_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: | 
 |             embedding_weights = torch.randint(0, 100, (embedding_rows, embedding_dims), dtype=weights_dtype) | 
 |         else: | 
 |             embedding_weights = torch.rand((embedding_rows, embedding_dims), dtype=weights_dtype) | 
 |         mask = self._generate_rowwise_mask(embedding_rows) | 
 |  | 
 |         def get_pt_result(embedding_weights, mask, indices_type): | 
 |             return torch._rowwise_prune(embedding_weights, mask, indices_type) | 
 |  | 
 |         # Reference implementation. | 
 |         def get_reference_result(embedding_weights, mask, indices_type): | 
 |             num_embeddings = mask.size()[0] | 
 |             compressed_idx_out = torch.zeros(num_embeddings, dtype=indices_type) | 
 |             pruned_weights_out = embedding_weights[mask[:]] | 
 |             idx = 0 | 
 |             for i in range(mask.size()[0]): | 
 |                 if mask[i]: | 
 |                     compressed_idx_out[i] = idx | 
 |                     idx = idx + 1 | 
 |                 else: | 
 |                     compressed_idx_out[i] = -1 | 
 |             return (pruned_weights_out, compressed_idx_out) | 
 |  | 
 |         pt_pruned_weights, pt_compressed_indices_map = get_pt_result( | 
 |             embedding_weights, mask, indices_type) | 
 |         ref_pruned_weights, ref_compressed_indices_map = get_reference_result( | 
 |             embedding_weights, mask, indices_type) | 
 |  | 
 |         torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights) | 
 |         self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map) | 
 |         self.assertEqual(pt_compressed_indices_map.dtype, indices_type) | 
 |  | 
 |  | 
 |     @given( | 
 |         embedding_rows=st.integers(1, 100), | 
 |         embedding_dims=st.integers(1, 100), | 
 |         weights_dtype=st.sampled_from([torch.float64, torch.float32, | 
 |                                        torch.float16, torch.int8, | 
 |                                        torch.int16, torch.int32, torch.int64]) | 
 |     ) | 
 |     def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype): | 
 |         self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype) | 
 |  | 
 |  | 
 |     @given( | 
 |         embedding_rows=st.integers(1, 100), | 
 |         embedding_dims=st.integers(1, 100), | 
 |         weights_dtype=st.sampled_from([torch.float64, torch.float32, | 
 |                                        torch.float16, torch.int8, | 
 |                                        torch.int16, torch.int32, torch.int64]) | 
 |     ) | 
 |     def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype): | 
 |         self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype) |