| # Owner(s): ["oncall: mobile"] |
| |
| import torch |
| import torch.utils.bundled_inputs |
| import io |
| |
| from torch.jit.mobile import _load_for_lite_interpreter |
| from torch.testing._internal.common_utils import TestCase, run_tests |
| from pathlib import Path |
| from itertools import product |
| |
# Directory two levels above this file (the parent of the directory that
# contains it). Model fixtures used by the tests are addressed relative to it.
pytorch_test_dir = Path(__file__).resolve().parents[1]
| |
class TestLiteScriptModule(TestCase):
    """Lite-interpreter tests for versioned operators (currently ``div``)."""

    def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule):
        """Round-trip ``script_module`` through the lite-interpreter format.

        The module is serialized in memory with mobile debug info enabled and
        then loaded back with ``_load_for_lite_interpreter``.

        Args:
            script_module: The scripted module to serialize.

        Returns:
            The mobile module produced by the lite-interpreter loader.
        """
        serialized = script_module._save_to_buffer_for_lite_interpreter(
            _save_mobile_debug_info=True
        )
        # A freshly constructed BytesIO is already positioned at offset 0, so
        # no explicit seek is required before handing it to the loader.
        return _load_for_lite_interpreter(io.BytesIO(serialized))

    def _try_fn(self, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)``; return its result or the exception raised.

        The exception is returned rather than propagated on purpose: the
        comparison helper in this class calls two implementations on the same
        inputs and needs to see whether each one failed.
        """
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return e

    def test_versioned_div_tensor(self):
        """Load the v2 ``div`` tensor model directly and after re-serialization.

        NOTE(review): the two result comparisons at the bottom of the loop are
        commented out in this file, so at present the test only verifies that
        the model can be loaded by the lite interpreter, loaded by
        ``torch.jit.load``, and saved/re-loaded as a mobile module. The reason
        they are disabled is not visible here -- confirm before re-enabling.
        """

        def div_tensor_0_3(self, other):
            # Reference implementation used for the old operator: true division
            # when either operand is floating point, truncating division
            # otherwise.
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide(other)
            return self.divide(other, rounding_mode='trunc')

        def _assert_same_results(m, fn, a, b):
            # Run the module and the reference function on the same inputs and
            # require that they either both raise or agree on every output.
            m_results = self._try_fn(m, a, b)
            fn_result = self._try_fn(fn, a, b)

            if isinstance(m_results, Exception):
                self.assertIsInstance(fn_result, Exception)
                return

            # Fail with a clear message if only the reference function raised,
            # instead of an opaque TypeError from Tensor.eq below.
            self.assertNotIsInstance(fn_result, Exception)
            for result in m_results:
                # .all() keeps the truth value well defined even if an output
                # has more than one element.
                self.assertTrue(result.eq(fn_result).all())

        model_path = pytorch_test_dir / "cpp" / "jit" / "upgrader_models" / "test_versioned_div_tensor_v2.ptl"
        mobile_module_v2 = _load_for_lite_interpreter(str(model_path))
        jit_module_v2 = torch.jit.load(str(model_path))
        current_mobile_module = self._save_load_mobile_module(jit_module_v2)

        # Mix float and int operands so both branches of div_tensor_0_3 apply.
        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            # old operator should produce the same result as applying upgrader of torch.div op
            # _assert_same_results(mobile_module_v2, div_tensor_0_3, a, b)
            # latest operator should produce the same result as applying torch.div op
            # _assert_same_results(current_mobile_module, torch.div, a, b)
| |
# Script entry point: when this file is executed directly, hand control to the
# test runner imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()