| # Owner(s): ["oncall: mobile"] |
| |
| import io |
| from itertools import product |
| from pathlib import Path |
| |
| import torch |
| import torch.utils.bundled_inputs |
| |
| from torch.jit.mobile import _load_for_lite_interpreter |
| from torch.testing._internal.common_utils import run_tests, TestCase |
| |
# Directory two levels above this source file (symlinks resolved first); the
# versioned upgrader models used by the tests are located beneath it.
pytorch_test_dir = Path(__file__).resolve().parent.parent
| |
| |
class TestLiteScriptModule(TestCase):
    """Tests that load versioned TorchScript models with the mobile lite interpreter."""

    def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule):
        """Round-trip ``script_module`` through the lite-interpreter format.

        The module is serialized to an in-memory buffer (with mobile debug
        info enabled) and immediately reloaded with
        ``_load_for_lite_interpreter``.

        Returns:
            The reloaded mobile module.
        """
        buffer = io.BytesIO(
            script_module._save_to_buffer_for_lite_interpreter(
                _save_mobile_debug_info=True
            )
        )
        buffer.seek(0)
        return _load_for_lite_interpreter(buffer)

    def _try_fn(self, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)``; return the raised exception instead of propagating it.

        This lets callers compare *whether* two callables fail, not only what
        they return.
        """
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return e

    def _assert_same_results(self, m, fn, a, b):
        """Assert that ``m(a, b)`` agrees with the reference ``fn(a, b)``.

        If ``m`` raises, ``fn`` must raise too; if ``m`` succeeds, ``fn`` must
        succeed as well. Each item yielded by iterating ``m``'s result is then
        compared element-wise against ``fn``'s result with ``Tensor.eq``
        (kept from the original draft instead of ``assertEqual``, which may
        additionally enforce matching dtypes).

        NOTE(review): assumes ``m`` returns an iterable of tensors (e.g. a
        tuple of outputs) - TODO confirm against the model.
        """
        m_results = self._try_fn(m, a, b)
        fn_result = self._try_fn(fn, a, b)

        if isinstance(m_results, Exception):
            self.assertIsInstance(fn_result, Exception)
            return

        # Without this check a failing reference would surface as a confusing
        # TypeError from ``result.eq(<exception>)`` below.
        self.assertNotIsInstance(fn_result, Exception)
        for result in m_results:
            # ``.all()`` keeps the check meaningful for multi-element tensors.
            self.assertTrue(result.eq(fn_result).all())

    def test_versioned_div_tensor(self):
        """Load the v2 ``div`` model and (eventually) compare it to reference semantics.

        NOTE(review): the result comparisons below are currently disabled, so
        this test only checks that the v2 model loads with both the lite
        interpreter and ``torch.jit.load``, and that it survives a
        lite-interpreter save/load round trip.
        """

        def div_tensor_0_3(self, other):
            # Reference for the old ``div`` semantics (versions 0-3, going by
            # the name): true division if either operand is floating point,
            # truncating division otherwise.
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide(other)
            return self.divide(other, rounding_mode="trunc")

        model_path = (
            pytorch_test_dir
            / "cpp"
            / "jit"
            / "upgrader_models"
            / "test_versioned_div_tensor_v2.ptl"
        )
        mobile_module_v2 = _load_for_lite_interpreter(str(model_path))
        jit_module_v2 = torch.jit.load(str(model_path))
        current_mobile_module = self._save_load_mobile_module(jit_module_v2)

        # Mix float and int scalars so both the true-divide and the
        # truncating-divide paths would be covered.
        vals = (2.0, 3.0, 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            # TODO: these checks are disabled, so the loop currently only builds
            # inputs. Re-enable them once the expected results are confirmed.
            # The old operator should match applying the upgrader of torch.div:
            # self._assert_same_results(mobile_module_v2, div_tensor_0_3, a, b)
            # The latest operator should match applying torch.div:
            # self._assert_same_results(current_mobile_module, torch.div, a, b)
| |
| |
if __name__ == "__main__":
    # When executed directly as a script, hand off to ``run_tests`` (imported
    # from torch.testing._internal.common_utils) to run this module's tests.
    run_tests()