|  | # Owner(s): ["module: dynamo"] | 
|  | import dataclasses | 
|  | import unittest.mock | 
|  |  | 
|  | import torch | 
|  |  | 
|  | import torch._dynamo.test_case | 
|  | import torch._dynamo.testing | 
|  | from torch._dynamo.testing import same | 
|  |  | 
|  | try: | 
|  | from transformers import modeling_outputs | 
|  | from transformers.configuration_utils import PretrainedConfig | 
|  | from transformers.file_utils import ModelOutput | 
|  | from transformers.modeling_outputs import BaseModelOutput | 
|  | except ImportError: | 
|  | modeling_outputs = None | 
|  |  | 
|  |  | 
def maybe_skip(fn):
    """Mark ``fn`` as skipped when the optional HuggingFace import failed.

    ``modeling_outputs`` is bound to ``None`` by the guarded import at the top
    of the module when ``transformers`` is unavailable. The check happens once,
    at decoration time; otherwise ``fn`` is returned unchanged.
    """
    hf_missing = modeling_outputs is None
    return unittest.skipIf(hf_missing, "requires HuggingFace")(fn)
|  |  | 
|  |  | 
class TestHFPretrained(torch._dynamo.test_case.TestCase):
    """Dynamo tracing of attribute access on a HuggingFace ``PretrainedConfig``."""

    @maybe_skip
    def test_pretrained(self):
        """Branch on ``hasattr`` and config fields under ``nopython=True``."""

        def fn(t, cfg):
            # "somekey" is not among the kwargs passed to the config below,
            # so presumably this branch is not taken; either way the hasattr
            # result must be resolved while tracing.
            if hasattr(cfg, "somekey"):
                t = t + 1
            if not cfg.return_dict:
                return t
            return t + torch.ones(2) * cfg.max_length

        inp = torch.randn(2)
        config = PretrainedConfig(return_dict=True, max_length=20)
        expected = fn(inp, config)

        compiled_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        actual = compiled_fn(inp, config)
        self.assertTrue(same(expected, actual))
|  |  | 
|  |  | 
class TestModelOutput(torch._dynamo.test_case.TestCase):
    """Dynamo tracing of HuggingFace ``ModelOutput`` construction and access."""

    @maybe_skip
    def test_mo_create(self):
        """Build a ``BaseModelOutput`` from positional and keyword tensors."""

        def fn(a, b):
            tmp = BaseModelOutput(a + 1, attentions=b + 3)
            return tmp

        # Two ops: a + 1 and b + 3.
        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)

    @maybe_skip
    def test_mo_assign(self):
        """Set fields through both attribute and item assignment in-graph."""

        def fn(a, b):
            tmp = BaseModelOutput(last_hidden_state=b + 3)
            tmp.hidden_states = a + 7
            tmp["attentions"] = a + b + 6
            return tmp

        args = [torch.randn(10), torch.randn(10)]
        obj1 = fn(*args)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
        self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
        self.assertTrue(same(obj1.attentions, obj2.attentions))
        self.assertEqual(cnts.frame_count, 1)
        # Four ops: b + 3, a + 7, a + b, (a + b) + 6.
        self.assertEqual(cnts.op_count, 4)

    def _common(self, fn, op_count):
        """Compare eager and compiled ``fn`` on one ``BaseModelOutput`` input.

        The input sets ``last_hidden_state`` and ``attentions`` and leaves
        ``hidden_states`` unset. The same ``args`` feed both runs, so ``fn``
        must not mutate its input in place.

        Args:
            fn: Function taking a single ``BaseModelOutput``.
            op_count: Expected number of ops recorded by the compiled frame.
        """
        args = [
            BaseModelOutput(
                last_hidden_state=torch.randn(10), attentions=torch.randn(10)
            )
        ]
        obj1 = fn(*args)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, op_count)

    @maybe_skip
    def test_mo_getattr(self):
        """Read fields by attribute, guarding optional ones with ``is not None``."""

        def fn(obj: BaseModelOutput):
            x = obj.last_hidden_state * 10
            if obj.hidden_states is not None:
                x += obj.hidden_states
            if obj.attentions is not None:
                x += obj.attentions
            return x

        # Two ops: the multiply and the attentions add (hidden_states is unset).
        self._common(fn, 2)

    @maybe_skip
    def test_mo_getitem(self):
        """Read fields by key, guarding optional ones with ``in``."""

        def fn(obj: BaseModelOutput):
            x = obj["last_hidden_state"] * 10
            # Key was previously misspelled "hidden_stats", so this check could
            # never observe the real field. _common leaves hidden_states unset,
            # so presumably it is not a key and the branch is skipped; the
            # expected op_count of 2 relies on that.
            if "hidden_states" in obj:
                x += obj["hidden_states"]
            if "attentions" in obj:
                x += obj["attentions"]
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_tuple(self):
        """Unpack ``to_tuple()``; expects two entries, the fields _common sets."""

        def fn(obj: BaseModelOutput):
            a, b = obj.to_tuple()
            return a + b * 10

        self._common(fn, 2)

    @maybe_skip
    def test_mo_index(self):
        """Access fields by integer position."""

        def fn(obj: BaseModelOutput):
            return obj[0] * 10 + obj[1]

        self._common(fn, 2)

    @maybe_skip
    def test_mo_init(self):
        """Trace ``dataclasses.fields`` on a user-defined ``ModelOutput``."""

        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none

            # NOTE: `total` is the first field's tensor itself, so `+=` below
            # mutates the caller's input in place.
            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v

            return total

        tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
        # fn mutates its first field in place. Give each run its own copies;
        # otherwise correct1 aliases tensors[0] and the comparison below would
        # check that tensor against itself.
        obj1 = MyDataClass(*[t.clone() for t in tensors])
        correct1 = fn(obj1)

        obj2 = MyDataClass(*[t.clone() for t in tensors])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj2), correct1))
        self.assertEqual(cnts.frame_count, 1)
        # Two ops: b and c accumulated into a (d and e are None).
        self.assertEqual(cnts.op_count, 2)
|  |  | 
|  |  | 
if __name__ == "__main__":
    # Dynamo's test runner; the module is already imported at the top of file.
    torch._dynamo.test_case.run_tests()