|  | # Owner(s): ["module: unknown"] | 
|  |  | 
|  | import glob | 
|  | import io | 
|  | import os | 
|  | import unittest | 
|  |  | 
|  | import torch | 
|  | from torch.testing._internal.common_utils import TestCase, run_tests | 
|  |  | 
|  |  | 
# ``create_bundled`` lives in the repository's ``third_party`` package. When
# that package cannot be imported the name is set to ``None`` so that tests
# depending on it can be skipped instead of failing at import time.
try:
    from third_party.build_bundled import create_bundled
except ImportError:
    create_bundled = None

# Checked-in bundle of third-party licenses (relative to the working directory).
license_file = "third_party/LICENSES_BUNDLED.txt"
# Text that must appear in the LICENSE shipped with an installed torch.
starting_txt = "The Pytorch repository and source distributions bundle"
# Directory two levels above ``torch/__init__.py``, i.e. the one holding the
# ``torch`` package, and every "torch-*dist-info" directory found inside it.
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info"))
|  |  | 
class TestLicense(TestCase):
    """Tests for the bundled third-party license text.

    One test compares the checked-in bundle with a freshly generated one;
    the other checks the LICENSE file of an installed torch distribution.
    """

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Check that ``license_file`` matches the current third_party tree.

        The bundle is regenerated in memory with ``create_bundled`` and
        compared with the file on disk. Both ``'third_party'`` and
        ``license_file`` are relative paths, so they are resolved against
        the current working directory.

        Raises:
            AssertionError: if the regenerated text differs from the file.
        """
        current = io.StringIO()
        create_bundled('third_party', current)
        with open(license_file) as fid:
            src_tree = fid.read()
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                'match the current state of the third_party files. Use '
                '"python third_party/build_bundled.py" to regenerate it')

    @unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test")
    def test_distinfo_license(self):
        """Check the LICENSE shipped in the installed torch dist-info.

        If run when pytorch is installed via a wheel, the license will be in
        site-package/torch-*dist-info/licenses/LICENSE (PEP 639 layout) or
        site-package/torch-*dist-info/LICENSE (legacy layout). Make sure it
        contains the third party bundle of licenses.

        Raises:
            AssertionError: if more than one "torch-*dist-info" directory is
                found, or if the license lacks the expected text.
        """
        if len(distinfo) > 1:
            raise AssertionError('Found too many "torch-*dist-info" directories '
                                 f'in "{site_packages}", expected only one')
        # PEP 639 moved license files into a "licenses" subdirectory of
        # dist-info; prefer that location and fall back to the legacy one.
        license_path = os.path.join(distinfo[0], 'licenses', 'LICENSE')
        if not os.path.exists(license_path):
            license_path = os.path.join(distinfo[0], 'LICENSE')
        with open(license_path) as fid:
            txt = fid.read()
        # assertIn reports both operands on failure, unlike assertTrue(a in b).
        self.assertIn(starting_txt, txt)
|  |  | 
# Hand control to PyTorch's test runner when this file is executed directly.
if __name__ == "__main__":
    run_tests()