import pytest
import torch
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
@pytest.mark.parametrize('shape', [
    (2, 2),
    (3, 3),
    (2, 4, 4),
    (2, 2, 4, 4),
])
def test_tril_matrix_to_vec(shape):
    """Check tril <-> vec conversion for every diagonal offset of a random matrix.

    For each ``diag`` in ``[-n, n - 1]`` (``n`` is the size of the trailing
    square dimensions) this checks two things:

    * ``tril_matrix_to_vec(mat.tril(diag), diag)`` returns exactly the entries on
      or below ``diag``, in row-major order. It is cross-checked against the
      independent ``torch.tril_indices`` reference.
    * ``vec_to_tril_matrix`` rebuilds ``mat.tril(diag)`` exactly from that vector.

    Args:
        shape: Shape of the random input. The last two dims must be equal; any
            leading dims act as batch dims.
    """
    # A private, seeded generator keeps the input reproducible across runs
    # without disturbing the global RNG state that other tests may rely on.
    generator = torch.Generator().manual_seed(0)
    mat = torch.randn(shape, generator=generator)
    n = mat.shape[-1]
    for diag in range(-n, n):
        actual = mat.tril(diag)
        vec = tril_matrix_to_vec(actual, diag)

        # Independent reference: tril_indices lists (row, col) pairs row by row.
        # Comparing against it catches a wrong length or element order that a
        # pure round-trip check would miss.
        rows, cols = torch.tril_indices(n, n, diag)
        torch.testing.assert_close(vec, actual[..., rows, cols], rtol=0, atol=0)

        # Both conversions only gather/scatter elements, so the round trip must
        # be exact rather than merely close.
        tril_mat = vec_to_tril_matrix(vec, diag)
        torch.testing.assert_close(tril_mat, actual, rtol=0, atol=0)
if __name__ == '__main__':
    # pytest.main returns the session exit code. Propagate it so that running
    # this file directly reports failures to the shell/CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__]))