blob: 2b6f8fe2453f971ce1f80a9cb77b179d17c926ca [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import dataclasses
import unittest.mock
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
try:
from transformers import modeling_outputs
from transformers.configuration_utils import PretrainedConfig
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
)
except ImportError:
modeling_outputs = None
def maybe_skip(fn):
    """Mark ``fn`` as skipped when the HuggingFace imports above failed.

    ``modeling_outputs`` is set to ``None`` by the module-level
    ``except ImportError`` handler, so it doubles as an availability flag.

    Returns ``fn`` unchanged when HuggingFace is importable, otherwise the
    ``unittest.skip``-wrapped version of it.
    """
    hf_available = modeling_outputs is not None
    return fn if hf_available else unittest.skip("requires HuggingFace")(fn)
class TestHFPretrained(torch._dynamo.test_case.TestCase):
    """Dynamo tracing of code that reads a HuggingFace ``PretrainedConfig``."""

    @maybe_skip
    def test_pretrained(self):
        # The traced function probes the config with ``hasattr`` and reads
        # ``return_dict`` / ``max_length``; nopython=True requires the whole
        # thing to compile in a single graph.
        def fn(a, config):
            if hasattr(config, "somekey"):
                a = a + 1
            if config.return_dict:
                return a + torch.ones(2) * config.max_length
            return a

        inp = torch.randn(2)
        config = PretrainedConfig(return_dict=True, max_length=20)
        expected = fn(inp, config)

        compiled = torch._dynamo.optimize("eager", nopython=True)(fn)
        actual = compiled(inp, config)
        self.assertTrue(same(expected, actual))
class TestModelOutput(torch._dynamo.test_case.TestCase):
    """Dynamo tracing of HuggingFace ``ModelOutput`` objects.

    Covers construction, attribute and item reads/writes, ``to_tuple``,
    integer indexing, ``dataclasses.fields`` introspection of a custom
    ``ModelOutput`` subclass, and a small BERT-style module whose submodules
    build and return ``ModelOutput`` instances.
    """

    @maybe_skip
    def test_mo_create(self):
        def fn(a, b):
            tmp = BaseModelOutput(a + 1, attentions=b + 3)
            return tmp

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)

    @maybe_skip
    def test_mo_assign(self):
        def fn(a, b):
            tmp = BaseModelOutput(last_hidden_state=b + 3)
            # exercise both attribute-style and item-style assignment
            tmp.hidden_states = a + 7
            tmp["attentions"] = a + b + 6
            return tmp

        args = [torch.randn(10), torch.randn(10)]
        obj1 = fn(*args)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
        self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
        self.assertTrue(same(obj1.attentions, obj2.attentions))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

    def _common(self, fn, op_count):
        """Compare eager vs. compiled ``fn`` on a ``BaseModelOutput`` input.

        The input has ``last_hidden_state`` and ``attentions`` set and all
        other fields left unset. Asserts matching results, exactly one
        compiled frame, and exactly ``op_count`` recorded ops.
        """
        args = [
            BaseModelOutput(
                last_hidden_state=torch.randn(10), attentions=torch.randn(10)
            )
        ]
        obj1 = fn(*args)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, op_count)

    @maybe_skip
    def test_mo_getattr(self):
        def fn(obj: BaseModelOutput):
            x = obj.last_hidden_state * 10
            if obj.hidden_states is not None:
                x += obj.hidden_states
            if obj.attentions is not None:
                x += obj.attentions
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_getitem(self):
        def fn(obj: BaseModelOutput):
            x = obj["last_hidden_state"] * 10
            # _common leaves hidden_states unset, so this branch is expected
            # not to run (op_count=2 counts only the multiply and the
            # attentions add).
            if "hidden_states" in obj:
                x += obj["hidden_states"]
            if "attentions" in obj:
                x += obj["attentions"]
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_tuple(self):
        def fn(obj: BaseModelOutput):
            a, b = obj.to_tuple()
            return a + b * 10

        self._common(fn, 2)

    @maybe_skip
    def test_mo_index(self):
        def fn(obj: BaseModelOutput):
            return obj[0] * 10 + obj[1]

        self._common(fn, 2)

    @maybe_skip
    def test_mo_init(self):
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none

            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v

            return total

        tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
        expected = tensors[0] + tensors[1] + tensors[2]

        # fn accumulates in place into the first field (``total += v``) and
        # returns that same tensor. Give each call its own copies: sharing
        # ``tensors`` would make the eager and compiled results alias one
        # tensor, turning the comparison below into a self-comparison.
        obj1 = MyDataClass(*[t.clone() for t in tensors])
        correct1 = fn(obj1)
        self.assertTrue(same(correct1, expected))

        obj2 = MyDataClass(*[t.clone() for t in tensors])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj2), correct1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @maybe_skip
    def test_HF_bert_model_output(self):
        # Nothing here is CUDA-specific, so fall back to CPU instead of
        # erroring on machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        class BertPooler(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dense = torch.nn.Linear(768, 768).to(device)
                self.activation = torch.nn.Tanh()

            def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token.
                first_token_tensor = hidden_states[:, 0]
                pooled_output = self.dense(first_token_tensor)
                pooled_output = self.activation(pooled_output)
                return pooled_output

        class BertEncoder(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(
                self,
                hidden_states: torch.Tensor,
            ) -> BaseModelOutputWithPastAndCrossAttentions:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                    cross_attentions=None,
                )

        class BertModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.encoder = BertEncoder()
                self.pooler = BertPooler()

            def forward(
                self,
                sequence_output: torch.Tensor,
            ) -> BaseModelOutputWithPoolingAndCrossAttentions:
                encoder_outputs = self.encoder(sequence_output)
                # test __getitem__ and to_tuple
                sequence_output = encoder_outputs[0]
                pooled_output = (
                    self.pooler(sequence_output) if self.pooler is not None else None
                )
                # test CustomDictVariable.create
                result = BaseModelOutputWithPoolingAndCrossAttentions(
                    last_hidden_state=sequence_output,
                    pooler_output=pooled_output,
                    past_key_values=encoder_outputs.past_key_values,
                    hidden_states=encoder_outputs.hidden_states,
                    attentions=encoder_outputs.attentions,
                    cross_attentions=encoder_outputs.cross_attentions,
                )
                # test __setattr__
                result.pooler_output = pooled_output
                # test __setitem__
                result["pooler_output"] = pooled_output
                return result

        sequence_output = torch.rand(1, 12, 768).to(device)
        model = BertModel()
        orig_result = model(sequence_output)
        compiled_model = torch.compile(model, backend="eager")
        compiled_result = compiled_model(sequence_output)
        self.assertTrue(
            torch.allclose(
                orig_result.last_hidden_state, compiled_result.last_hidden_state
            )
        )
        self.assertTrue(
            torch.allclose(orig_result.pooler_output, compiled_result.pooler_output)
        )
if __name__ == "__main__":
    # Run this module's tests through the dynamo test-case runner
    # (torch._dynamo.test_case is already imported at the top of the file).
    torch._dynamo.test_case.run_tests()