| #!/usr/bin/env python3 |
| # Owner(s): ["oncall: r2p"] |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
# LICENSE file in the root directory of this source tree.
| import abc |
| import unittest.mock as mock |
| |
| from torch.distributed.elastic.metrics.api import ( |
| _get_metric_name, |
| MetricData, |
| MetricHandler, |
| MetricStream, |
| prof, |
| ) |
| from torch.testing._internal.common_utils import run_tests, TestCase |
| |
| |
def foo_1():
    """No-op module-level function used as a target for ``_get_metric_name``."""
| |
| |
class TestMetricsHandler(MetricHandler):
    """In-memory ``MetricHandler`` that records the latest datum per metric name.

    Tests inspect ``metric_data`` after exercising ``@prof``-decorated callables.
    """

    # The name starts with "Test", so pytest would try to collect this helper
    # as a test class (and warn, since it defines ``__init__``); opt out.
    __test__ = False

    def __init__(self) -> None:
        super().__init__()
        # Maps metric name -> most recently emitted MetricData for that name;
        # a later emit with the same name overwrites the earlier one.
        self.metric_data: dict[str, MetricData] = {}

    def emit(self, metric_data: MetricData) -> None:
        """Store ``metric_data`` keyed by its ``name`` attribute."""
        self.metric_data[metric_data.name] = metric_data
| |
| |
class Parent(abc.ABC):
    """Abstract base whose ``func`` hook is reached indirectly via ``base_func``.

    Lets the tests verify that a ``@prof``-decorated override still emits
    metrics when it is invoked through a method defined on the base class.
    """

    @abc.abstractmethod
    def func(self):
        """Hook every concrete subclass must override."""
        raise NotImplementedError()

    def base_func(self):
        """Call ``self.func()`` (dynamically dispatched) and return ``None``."""
        self.func()
| |
| |
class Child(Parent):
    """Concrete ``Parent`` whose ``func`` override is wrapped by ``@prof``."""

    # Decorate the concrete implementation, not the abstract method on Parent:
    # overriding ``func`` here replaces whatever Parent.func was, including any
    # decorator applied to it, so only a decorator placed here takes effect.
    @prof
    def func(self):
        """Do nothing; only the metrics emitted by ``@prof`` are of interest."""
| |
| |
class MetricsApiTest(TestCase):
    """Tests for ``_get_metric_name`` and the ``prof`` decorator."""

    def foo_2(self):
        """Undecorated method used as a target for ``_get_metric_name``."""

    @prof
    def bar(self):
        """Profiled method that completes successfully."""

    @prof
    def throw(self):
        """Profiled method that always raises ``RuntimeError``."""
        raise RuntimeError

    @prof(group="torchelastic")
    def bar2(self):
        """Profiled method that reports under an explicit metric group."""

    def test_get_metric_name(self):
        """Metric names are ``<module>.<name>`` for functions and
        ``<Class>.<name>`` for methods."""
        # Note: since pytorch uses main method to launch tests,
        # the module will be different between fb and oss, this
        # allows keeping the module name consistent.
        original_module = foo_1.__module__
        foo_1.__module__ = "api_test"
        try:
            self.assertEqual("api_test.foo_1", _get_metric_name(foo_1))
        finally:
            # Restore the module-level function so the override does not leak
            # into other tests in the same process.
            foo_1.__module__ = original_module
        self.assertEqual("MetricsApiTest.foo_2", _get_metric_name(self.foo_2))

    def test_profile(self):
        """``@prof`` records success/failure counters, duration, and group."""
        handler = TestMetricsHandler()
        stream = MetricStream("torchelastic", handler)
        # patch instead of configure to avoid conflicts when running tests in parallel
        with mock.patch(
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
        ):
            self.bar()

            self.assertEqual(
                1, handler.metric_data["MetricsApiTest.bar.success"].value
            )
            self.assertNotIn("MetricsApiTest.bar.failure", handler.metric_data)
            self.assertIn("MetricsApiTest.bar.duration.ms", handler.metric_data)

            with self.assertRaises(RuntimeError):
                self.throw()

            self.assertEqual(
                1, handler.metric_data["MetricsApiTest.throw.failure"].value
            )
            # Must use the real method name ("throw"); a key for a method that
            # does not exist would make this assertion pass unconditionally.
            self.assertNotIn("MetricsApiTest.throw.success", handler.metric_data)
            self.assertIn("MetricsApiTest.throw.duration.ms", handler.metric_data)

            self.bar2()
            self.assertEqual(
                "torchelastic",
                handler.metric_data["MetricsApiTest.bar2.success"].group_name,
            )

    def test_inheritance(self):
        """A profiled override is measured when called via the base class."""
        handler = TestMetricsHandler()
        stream = MetricStream("torchelastic", handler)
        # patch instead of configure to avoid conflicts when running tests in parallel
        with mock.patch(
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
        ):
            c = Child()
            c.base_func()

            self.assertEqual(1, handler.metric_data["Child.func.success"].value)
            self.assertIn("Child.func.duration.ms", handler.metric_data)
| |
| |
if __name__ == "__main__":
    # Executed directly as a script: hand off to the shared PyTorch test runner.
    run_tests()