# Owner(s): ["oncall: mobile"]
import io
import tempfile
import unittest

import torch.utils.show_pickle
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
class TestShowPickle(TestCase):
    """Smoke test for the ``torch.utils.show_pickle`` entry point."""

    @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
    def test_scripted_model(self):
        """Save a scripted module, dump its ``data.pkl``, and check the dump.

        The dump is captured in an in-memory text buffer and must mention
        both the module's class name and its ``weight`` attribute.
        """

        class MyCoolModule(torch.nn.Module):
            """Tiny module whose only state is ``weight``."""

            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return x * self.weight

        m = torch.jit.script(MyCoolModule(torch.tensor([2.0])))

        with tempfile.NamedTemporaryFile() as tmp:
            torch.jit.save(m, tmp)
            # show_pickle is handed the file *name*, so everything written
            # through this handle must reach disk before it is re-opened.
            tmp.flush()

            buf = io.StringIO()
            torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf)
            output = buf.getvalue()

            self.assertRegex(output, "MyCoolModule")
            self.assertRegex(output, "weight")
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    run_tests()