# blob: c041f561f679d911e07bd883dce811c1f70552af
from torch.testing._internal.common_utils import (
TestCase, run_tests,
)
from datetime import timedelta, datetime
import time
from torch.monitor import (
Aggregation,
FixedCountStat,
IntervalStat,
Event,
log_event,
register_event_handler,
unregister_event_handler,
)
class TestMonitor(TestCase):
def test_interval_stat(self) -> None:
events = []
def handler(event):
events.append(event)
handle = register_event_handler(handler)
s = IntervalStat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
timedelta(milliseconds=1),
)
s.add(2)
time.sleep(0.002)
s.add(3)
self.assertEqual(s.name, "asdf")
self.assertGreaterEqual(len(events), 1)
unregister_event_handler(handle)
def test_fixed_count_stat(self) -> None:
s = FixedCountStat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
3,
)
s.add(1)
s.add(2)
name = s.name
self.assertEqual(name, "asdf")
self.assertEqual(s.count, 2)
s.add(3)
self.assertEqual(s.count, 0)
self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
def test_log_event(self) -> None:
e = Event(
name="torch.monitor.TestEvent",
timestamp=datetime.now(),
data={
"str": "a string",
"float": 1234.0,
"int": 1234,
},
)
self.assertEqual(e.name, "torch.monitor.TestEvent")
self.assertIsNotNone(e.timestamp)
self.assertIsNotNone(e.data)
log_event(e)
def test_event_handler(self) -> None:
events = []
def handler(event: Event) -> None:
events.append(event)
handle = register_event_handler(handler)
e = Event(
name="torch.monitor.TestEvent",
timestamp=datetime.now(),
data={},
)
log_event(e)
self.assertEqual(len(events), 1)
self.assertEqual(events[0], e)
log_event(e)
self.assertEqual(len(events), 2)
unregister_event_handler(handle)
log_event(e)
self.assertEqual(len(events), 2)
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test runner.
    run_tests()