Enable flatbuffer tests properly. (#98363)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98363
Approved by: https://github.com/angelayi
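
This removes the unittest.skipIf guard keyed on the ENABLE_FLATBUFFER environment variable so that TestSaveLoadFlatbuffer always runs, and replaces the single hard-coded expected dict in the module-info test with per-field assertions (updated bytecode/operator versions and a suffix check on the mangled function name). For context, the flatbuffer path these tests exercise looks roughly like the sketch below; the Foo module with a single nn.Linear is an illustrative assumption, not code taken from this patch.

import io

import torch


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative module; the real tests define their own fixtures.
        self.fc = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.fc(x)


script_module = torch.jit.script(Foo())

# Serialize the scripted module to an in-memory flatbuffer.
buffer = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(script_module, buffer)
buffer.seek(0)

# Inspect the flatbuffer header; the updated test checks these fields
# individually instead of comparing one hard-coded dict.
info = torch.jit._serialization.get_flatbuffer_module_info(buffer)
print(info["bytecode_version"], info["operator_version"])
print(info["function_names"], info["type_names"], info["opname_to_num_args"])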
diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py
index 8902209..d5b5a86 100644
--- a/test/jit/test_save_load.py
+++ b/test/jit/test_save_load.py
@@ -4,7 +4,6 @@
import os
import pathlib
import sys
-import unittest
from typing import NamedTuple, Optional
import torch
@@ -16,7 +15,6 @@
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry
-ENABLE_FLATBUFFER = os.environ.get("ENABLE_FLATBUFFER", "0") == "1"
if __name__ == "__main__":
raise RuntimeError(
@@ -675,10 +673,6 @@
return module_buffer
-# TODO(qihqi): This param is not setup correctly.
-@unittest.skipIf(
- not ENABLE_FLATBUFFER, "Need to enable flatbuffer to run the below tests"
-)
class TestSaveLoadFlatbuffer(JitTestCase):
def test_different_modules(self):
"""
@@ -1052,15 +1046,14 @@
torch.jit.save_jit_module_to_flatbuffer(
first_script_module, first_saved_module)
first_saved_module.seek(0)
- expected = {
- 'bytecode_version': 4,
- 'operator_version': 4,
- 'function_names': {'__torch__.___torch_mangle_0.Foo.forward'},
- 'type_names': set(),
- 'opname_to_num_args': {'aten::linear': 3}}
- self.assertEqual(
- torch.jit._serialization.get_flatbuffer_module_info(first_saved_module),
- expected)
+ ff_info = torch.jit._serialization.get_flatbuffer_module_info(first_saved_module)
+ self.assertEqual(ff_info['bytecode_version'], 9)
+ self.assertEqual(ff_info['operator_version'], 1)
+ self.assertEqual(ff_info['type_names'], set())
+ self.assertEqual(ff_info['opname_to_num_args'], {'aten::linear': 3})
+
+ self.assertEqual(len(ff_info['function_names']), 1)
+ self.assertTrue(next(iter(ff_info['function_names'])).endswith('forward'))
def test_save_load_params_buffers_submodules(self):
@@ -1124,12 +1117,11 @@
module = Module()
script_module = torch.jit.script(module)
- script_module_io = io.BytesIO()
- extra_files = {"abc.json": "[1,2,3]"}
- script_module._save_for_lite_interpreter(script_module_io, _extra_files=extra_files, _use_flatbuffer=True)
- script_module_io.seek(0)
+ extra_files = {"abc.json": b"[1,2,3]"}
+ script_module_io = script_module._save_to_buffer_for_lite_interpreter(
+ _extra_files=extra_files, _use_flatbuffer=True)
re_extra_files = {}
- torch._C._get_model_extra_files_from_buffer(script_module_io, _extra_files=re_extra_files)
+ torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)
self.assertEqual(extra_files, re_extra_files)
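
For reference, the extra-files round trip that the last hunk now exercises reduces to the following sketch; it reuses script_module from the snippet above and is not part of the patch itself.

# Extra files are attached at save time and read back from the raw buffer.
extra_files = {"abc.json": b"[1,2,3]"}
flatbuffer_bytes = script_module._save_to_buffer_for_lite_interpreter(
    _extra_files=extra_files, _use_flatbuffer=True)

recovered = {}
torch._C._get_model_extra_files_from_buffer(flatbuffer_bytes, recovered)
assert recovered == extra_files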