blob: c17e8d8e99f90dc07f83c7f22205357967f802da [file] [log] [blame]
#!/usr/bin/env python3
import io
import textwrap
from typing import List
import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests
def model_size(sm):
buffer = io.BytesIO()
torch.jit.save(sm, buffer)
return len(buffer.getvalue())
def save_and_load(sm):
buffer = io.BytesIO()
torch.jit.save(sm, buffer)
buffer.seek(0)
return torch.jit.load(buffer)
class TestBundledInputs(TestCase):
def test_single_tensors(self):
class SingleTensorModel(torch.nn.Module):
def forward(self, arg):
return arg
sm = torch.jit.script(SingleTensorModel())
original_size = model_size(sm)
get_expr : List[str] = []
samples = [
# Tensor with small numel and small storage.
(torch.tensor([1]),),
# Tensor with large numel and small storage.
(torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
# Tensor with small numel and large storage.
(torch.tensor(range(1 << 16))[-8:],),
# Large zero tensor.
(torch.zeros(1 << 16),),
# Large channels-last ones tensor.
(torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
# Special encoding of random tensor.
(torch.utils.bundled_inputs.bundle_randn(1 << 16),),
]
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
sm, samples, get_expr)
# print(get_expr[0])
# print(sm._generate_bundled_inputs.code)
# Make sure the model only grew a little bit,
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
self.assertTrue(loaded.run_on_bundled_input(0) is inflated[0][0])
for idx, inp in enumerate(inflated):
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.
self.assertEqual(inp[0].stride(), samples[idx][0].stride())
self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)
# This tensor is random, but with 100,000 trials,
# mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
self.assertEqual(inflated[5][0].shape, (1 << 16,))
self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)
def test_large_tensor_with_inflation(self):
class SingleTensorModel(torch.nn.Module):
def forward(self, arg):
return arg
sm = torch.jit.script(SingleTensorModel())
sample_tensor = torch.randn(1 << 16)
# We can store tensors with custom inflation functions regardless
# of size, even if inflation is just the identity.
sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
sm, [(sample,)])
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(len(inflated), 1)
self.assertEqual(inflated[0][0], sample_tensor)
def test_rejected_tensors(self):
def check_tensor(sample):
# Need to define the class in this scope to get a fresh type for each run.
class SingleTensorModel(torch.nn.Module):
def forward(self, arg):
return arg
sm = torch.jit.script(SingleTensorModel())
with self.assertRaisesRegex(Exception, "Bundled input argument"):
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
sm, [(sample,)])
# Plain old big tensor.
check_tensor(torch.randn(1 << 16))
# This tensor has two elements, but they're far apart in memory.
# We currently cannot represent this compactly while preserving
# the strides.
small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
self.assertEqual(small_sparse.numel(), 2)
check_tensor(small_sparse)
def test_non_tensors(self):
class StringAndIntModel(torch.nn.Module):
def forward(self, fmt: str, num: int):
return fmt.format(num)
sm = torch.jit.script(StringAndIntModel())
samples = [
("first {}", 1),
("second {}", 2),
]
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
sm, samples)
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
self.assertTrue(loaded.run_on_bundled_input(0) == "first 1")
def test_multiple_methods_with_inputs(self):
class MultipleMethodModel(torch.nn.Module):
def forward(self, arg):
return arg
@torch.jit.export
def foo(self, arg):
return arg
mm = torch.jit.script(MultipleMethodModel())
samples = [
# Tensor with small numel and small storage.
(torch.tensor([1]),),
# Tensor with large numel and small storage.
(torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
# Tensor with small numel and large storage.
(torch.tensor(range(1 << 16))[-8:],),
# Large zero tensor.
(torch.zeros(1 << 16),),
# Large channels-last ones tensor.
(torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
]
info = [
'Tensor with small numel and small storage.',
'Tensor with large numel and small storage.',
'Tensor with small numel and large storage.',
'Large zero tensor.',
'Large channels-last ones tensor.',
'Special encoding of random tensor.',
]
torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
mm,
inputs={
mm.forward : samples,
mm.foo : samples
},
info={
mm.forward : info,
mm.foo : info
}
)
loaded = save_and_load(mm)
inflated = loaded.get_all_bundled_inputs()
# Make sure these functions are all consistent.
self.assertEqual(inflated, samples)
self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())
# Check running and size helpers
self.assertTrue(loaded.run_on_bundled_input(0) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
# Check helper that work on all functions
all_info = loaded.get_bundled_inputs_functions_and_info()
self.assertEqual(set(all_info.keys()), set(['forward', 'foo']))
self.assertEqual(all_info['forward']['get_inputs_function_name'], ['get_all_bundled_inputs_for_forward'])
self.assertEqual(all_info['foo']['get_inputs_function_name'], ['get_all_bundled_inputs_for_foo'])
self.assertEqual(all_info['forward']['info'], info)
self.assertEqual(all_info['foo']['info'], info)
# example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs
for func_name in all_info.keys():
input_func_name = all_info[func_name]['get_inputs_function_name'][0]
func_to_run = getattr(loaded, input_func_name)
self.assertEqual(func_to_run(), samples)
def test_multiple_methods_with_inputs_failures(self):
class MultipleMethodModel(torch.nn.Module):
def forward(self, arg):
return arg
@torch.jit.export
def foo(self, arg):
return arg
samples = [(torch.tensor([1]),)]
# Test Failure Case both defined
try:
nn = torch.jit.script(MultipleMethodModel())
definition = textwrap.dedent("""
def _generate_bundled_inputs_for_forward(self):
return []
""")
nn.define(definition)
torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
nn,
inputs={
nn.forward : samples,
nn.foo : samples,
},
)
except Exception:
pass
else:
self.fail(
"Expected augment_many_model_functions_with_bundled_inputs to throw due to _generate_bundled_inputs_for_forward"
+ " and inputs being defined, got no exception instead."
)
# Test Failure Case neither defined
try:
mm = torch.jit.script(MultipleMethodModel())
torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
mm,
inputs={
mm.forward : None,
mm.foo : samples,
},
)
except Exception:
pass
else:
self.fail(
"Expected augment_many_model_functions_with_bundled_inputs to throw due to not having bundled inputs for forward,"
+ " got no exception instead."
)
if __name__ == '__main__':
run_tests()