Expunge torch.utils.trainer.* (#12487)
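
Remove the torch.utils.trainer package (Trainer plus the ProgressMonitor /
AccuracyMonitor / TimeMonitor / LossMonitor / Logger plugins) together with
its tests in test/test_utils.py.

Callers that still import torch.utils.trainer.Trainer can replace it with an
ordinary training loop. A minimal sketch, assuming the usual model /
criterion / optimizer / DataLoader objects; the names below are placeholders
introduced for illustration and are not part of this change:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-ins for the objects previously passed to Trainer(...).
    model = nn.Linear(10, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loader = DataLoader(TensorDataset(torch.randn(20, 10),
                                      torch.randint(0, 2, (20,))),
                        batch_size=2)

    for epoch in range(1, 4):           # what Trainer.run(epochs=3) iterated over
        running_loss = 0.0
        for inputs, targets in loader:  # roughly Trainer.train()
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()            # closure-based optimizers still take optimizer.step(closure)
            running_loss += loss.item()
        # Roughly what LossMonitor + Logger printed once per epoch.
        print('epoch {}: mean loss {:.4f}'.format(epoch, running_loss / len(loader)))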

Differential Revision: D10273602

Pulled By: SsnL

fbshipit-source-id: 630c1f8ee0e366f7092d4f93dbe1efa96fc860e0
diff --git a/test/test_utils.py b/test/test_utils.py
index dff6102..cca61b5 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -14,9 +14,6 @@
 import torch.cuda
 import warnings
 from torch.utils.checkpoint import checkpoint, checkpoint_sequential
-from torch.utils.trainer import Trainer
-from torch.utils.trainer.plugins import *
-from torch.utils.trainer.plugins.plugin import Plugin
 from torch.autograd._functions.utils import prepare_onnx_paddings
 from torch.autograd._functions.utils import check_onnx_broadcast
 from common import IS_WINDOWS, IS_PPC, skipIfRocm
@@ -26,85 +23,6 @@
 from common import TestCase, run_tests, download_file
 
 
-class SimplePlugin(Plugin):
-
-    def __init__(self, interval):
-        super(SimplePlugin, self).__init__(interval)
-        self.trainer = None
-        self.num_iteration = 0
-        self.num_epoch = 0
-        self.num_batch = 0
-        self.num_update = 0
-
-    def register(self, trainer):
-        self.trainer = trainer
-
-    def iteration(self, *args):
-        self.iteration_args = args
-        self.num_iteration += 1
-
-    def epoch(self, *args):
-        self.epoch_args = args
-        self.num_epoch += 1
-
-    def batch(self, *args):
-        self.batch_args = args
-        self.num_batch += 1
-
-    def update(self, *args):
-        self.update_args = args
-        self.num_update += 1
-
-
-class ModelMock(object):
-
-    def __init__(self):
-        self.num_calls = 0
-        self.output = torch.ones(1, 1, requires_grad=True)
-
-    def __call__(self, i):
-        self.num_calls += 1
-        return self.output * 2
-
-
-class CriterionMock(object):
-
-    def __init__(self):
-        self.num_calls = 0
-
-    def __call__(self, out, target):
-        self.num_calls += 1
-        return out
-
-
-class OptimizerMock(object):
-    max_evals = 5
-    min_evals = 1
-
-    def __init__(self):
-        self.num_steps = 0
-        self.num_evals = 0
-
-    def step(self, closure):
-        for i in range(random.randint(self.min_evals, self.max_evals)):
-            loss = closure()
-            self.num_evals += 1
-        self.num_steps += 1
-
-    def zero_grad(self):
-        pass
-
-
-class DatasetMock(object):
-
-    def __iter__(self):
-        for i in range(10):
-            yield torch.randn(2, 10), torch.randperm(10)[:2]
-
-    def __len__(self):
-        return 10
-
-
 class RandomDatasetMock(object):
 
     def __getitem__(self, index):
@@ -279,84 +197,6 @@
         self.assertEqual(len(list(dataiter)), 1)
 
 
-class TestTrainer(TestCase):
-
-    intervals = [
-        [(1, 'iteration')],
-        [(1, 'epoch')],
-        [(1, 'batch')],
-        [(1, 'update')],
-        [(5, 'iteration')],
-        [(5, 'epoch')],
-        [(5, 'batch')],
-        [(5, 'update')],
-        [(1, 'iteration'), (1, 'epoch')],
-        [(5, 'update'), (1, 'iteration')],
-        [(2, 'epoch'), (1, 'batch')],
-    ]
-
-    def setUp(self):
-        self.optimizer = OptimizerMock()
-        self.trainer = Trainer(ModelMock(), CriterionMock(),
-                               self.optimizer, DatasetMock())
-        self.num_epochs = 3
-        self.dataset_size = len(self.trainer.dataset)
-        self.num_iters = self.num_epochs * self.dataset_size
-
-    def test_register_plugin(self):
-        for interval in self.intervals:
-            simple_plugin = SimplePlugin(interval)
-            self.trainer.register_plugin(simple_plugin)
-            self.assertEqual(simple_plugin.trainer, self.trainer)
-
-    def test_optimizer_step(self):
-        self.trainer.run(epochs=1)
-        self.assertEqual(self.trainer.optimizer.num_steps, 10)
-
-    def test_plugin_interval(self):
-        for interval in self.intervals:
-            self.setUp()
-            simple_plugin = SimplePlugin(interval)
-            self.trainer.register_plugin(simple_plugin)
-            self.trainer.run(epochs=self.num_epochs)
-            units = {
-                ('iteration', self.num_iters),
-                ('epoch', self.num_epochs),
-                ('batch', self.num_iters),
-                ('update', self.num_iters)
-            }
-            for unit, num_triggers in units:
-                call_every = None
-                for i, i_unit in interval:
-                    if i_unit == unit:
-                        call_every = i
-                        break
-                if call_every:
-                    expected_num_calls = math.floor(num_triggers / call_every)
-                else:
-                    expected_num_calls = 0
-                num_calls = getattr(simple_plugin, 'num_' + unit)
-                self.assertEqual(num_calls, expected_num_calls, 0)
-
-    def test_model_called(self):
-        self.trainer.run(epochs=self.num_epochs)
-        num_model_calls = self.trainer.model.num_calls
-        num_crit_calls = self.trainer.criterion.num_calls
-        self.assertEqual(num_model_calls, num_crit_calls)
-        for num_calls in [num_model_calls, num_crit_calls]:
-            lower_bound = OptimizerMock.min_evals * self.num_iters
-            upper_bound = OptimizerMock.max_evals * self.num_iters
-            self.assertEqual(num_calls, self.trainer.optimizer.num_evals)
-            self.assertLessEqual(lower_bound, num_calls)
-            self.assertLessEqual(num_calls, upper_bound)
-
-    def test_model_gradient(self):
-        self.trainer.run(epochs=self.num_epochs)
-        output_var = self.trainer.model.output
-        expected_grad = torch.ones(1, 1) * 2 * self.optimizer.num_evals
-        self.assertEqual(output_var.grad.data, expected_grad)
-
-
 test_dir = os.path.abspath(os.path.dirname(str(__file__)))
 
 
diff --git a/torch/utils/trainer/__init__.py b/torch/utils/trainer/__init__.py
deleted file mode 100644
index 29340cf..0000000
--- a/torch/utils/trainer/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-from .trainer import Trainer
diff --git a/torch/utils/trainer/plugins/__init__.py b/torch/utils/trainer/plugins/__init__.py
deleted file mode 100644
index e8d10f4..0000000
--- a/torch/utils/trainer/plugins/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .progress import ProgressMonitor
-from .accuracy import AccuracyMonitor
-from .time import TimeMonitor
-from .loss import LossMonitor
-from .logger import Logger
diff --git a/torch/utils/trainer/plugins/accuracy.py b/torch/utils/trainer/plugins/accuracy.py
deleted file mode 100644
index f6f393c..0000000
--- a/torch/utils/trainer/plugins/accuracy.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from .monitor import Monitor
-
-
-class AccuracyMonitor(Monitor):
-    stat_name = 'accuracy'
-
-    def __init__(self, *args, **kwargs):
-        kwargs.setdefault('unit', '%')
-        kwargs.setdefault('precision', 2)
-        super(AccuracyMonitor, self).__init__(*args, **kwargs)
-
-    def _get_value(self, iteration, input, target, output, loss):
-        batch_size = input.size(0)
-        predictions = output.max(1)[1].type_as(target)
-        correct = predictions.eq(target)
-        if not hasattr(correct, 'sum'):
-            correct = correct.cpu()
-        correct = correct.sum()
-        return 100. * correct / batch_size
diff --git a/torch/utils/trainer/plugins/logger.py b/torch/utils/trainer/plugins/logger.py
deleted file mode 100644
index 9bc2dfc..0000000
--- a/torch/utils/trainer/plugins/logger.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from collections import defaultdict
-from .plugin import Plugin
-
-
-class Logger(Plugin):
-    alignment = 4
-    separator = '#' * 80
-
-    def __init__(self, fields, interval=None):
-        if interval is None:
-            interval = [(1, 'iteration'), (1, 'epoch')]
-        super(Logger, self).__init__(interval)
-        self.field_widths = defaultdict(lambda: defaultdict(int))
-        self.fields = list(map(lambda f: f.split('.'), fields))
-
-    def _join_results(self, results):
-        joined_out = map(lambda i: (i[0], ' '.join(i[1])), results)
-        joined_fields = map(lambda i: '{}: {}'.format(i[0], i[1]), joined_out)
-        return '\t'.join(joined_fields)
-
-    def log(self, msg):
-        print(msg)
-
-    def register(self, trainer):
-        self.trainer = trainer
-
-    def gather_stats(self):
-        result = {}
-        return result
-
-    def _align_output(self, field_idx, output):
-        for output_idx, o in enumerate(output):
-            if len(o) < self.field_widths[field_idx][output_idx]:
-                num_spaces = self.field_widths[field_idx][output_idx] - len(o)
-                output[output_idx] += ' ' * num_spaces
-            else:
-                self.field_widths[field_idx][output_idx] = len(o)
-
-    def _gather_outputs(self, field, log_fields, stat_parent, stat, require_dict=False):
-        output = []
-        name = ''
-        if isinstance(stat, dict):
-            log_fields = stat.get(log_fields, [])
-            name = stat.get('log_name', '.'.join(field))
-            for f in log_fields:
-                output.append(f.format(**stat))
-        elif not require_dict:
-            name = '.'.join(field)
-            number_format = stat_parent.get('log_format', '')
-            unit = stat_parent.get('log_unit', '')
-            fmt = '{' + number_format + '}' + unit
-            output.append(fmt.format(stat))
-        return name, output
-
-    def _log_all(self, log_fields, prefix=None, suffix=None, require_dict=False):
-        results = []
-        for field_idx, field in enumerate(self.fields):
-            parent, stat = None, self.trainer.stats
-            for f in field:
-                parent, stat = stat, stat[f]
-            name, output = self._gather_outputs(field, log_fields,
-                                                parent, stat, require_dict)
-            if not output:
-                continue
-            self._align_output(field_idx, output)
-            results.append((name, output))
-        if not results:
-            return
-        output = self._join_results(results)
-        if prefix is not None:
-            self.log(prefix)
-        self.log(output)
-        if suffix is not None:
-            self.log(suffix)
-
-    def iteration(self, *args):
-        self._log_all('log_iter_fields')
-
-    def epoch(self, epoch_idx):
-        self._log_all('log_epoch_fields',
-                      prefix=self.separator + '\nEpoch summary:',
-                      suffix=self.separator,
-                      require_dict=True)
diff --git a/torch/utils/trainer/plugins/loss.py b/torch/utils/trainer/plugins/loss.py
deleted file mode 100644
index 1bd93f2..0000000
--- a/torch/utils/trainer/plugins/loss.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .monitor import Monitor
-
-
-class LossMonitor(Monitor):
-    stat_name = 'loss'
-
-    def _get_value(self, iteration, input, target, output, loss):
-        return loss.item()
diff --git a/torch/utils/trainer/plugins/monitor.py b/torch/utils/trainer/plugins/monitor.py
deleted file mode 100644
index b1e1d9f..0000000
--- a/torch/utils/trainer/plugins/monitor.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from .plugin import Plugin
-
-
-class Monitor(Plugin):
-
-    def __init__(self, running_average=True, epoch_average=True, smoothing=0.7,
-                 precision=None, number_format=None, unit=''):
-        if precision is None:
-            precision = 4
-        if number_format is None:
-            number_format = '.{}f'.format(precision)
-        number_format = ':' + number_format
-        super(Monitor, self).__init__([(1, 'iteration'), (1, 'epoch')])
-
-        self.smoothing = smoothing
-        self.with_running_average = running_average
-        self.with_epoch_average = epoch_average
-
-        self.log_format = number_format
-        self.log_unit = unit
-        self.log_epoch_fields = None
-        self.log_iter_fields = ['{last' + number_format + '}' + unit]
-        if self.with_running_average:
-            self.log_iter_fields += [' ({running_avg' + number_format + '}' + unit + ')']
-        if self.with_epoch_average:
-            self.log_epoch_fields = ['{epoch_mean' + number_format + '}' + unit]
-
-    def register(self, trainer):
-        self.trainer = trainer
-        stats = self.trainer.stats.setdefault(self.stat_name, {})
-        stats['log_format'] = self.log_format
-        stats['log_unit'] = self.log_unit
-        stats['log_iter_fields'] = self.log_iter_fields
-        if self.with_epoch_average:
-            stats['log_epoch_fields'] = self.log_epoch_fields
-        if self.with_epoch_average:
-            stats['epoch_stats'] = (0, 0)
-
-    def iteration(self, *args):
-        stats = self.trainer.stats.setdefault(self.stat_name, {})
-        stats['last'] = self._get_value(*args)
-
-        if self.with_epoch_average:
-            stats['epoch_stats'] = tuple(sum(t) for t in
-                                         zip(stats['epoch_stats'], (stats['last'], 1)))
-
-        if self.with_running_average:
-            previous_avg = stats.get('running_avg', 0)
-            stats['running_avg'] = previous_avg * self.smoothing + \
-                stats['last'] * (1 - self.smoothing)
-
-    def epoch(self, idx):
-        stats = self.trainer.stats.setdefault(self.stat_name, {})
-        if self.with_epoch_average:
-            epoch_stats = stats['epoch_stats']
-            stats['epoch_mean'] = epoch_stats[0] / epoch_stats[1]
-            stats['epoch_stats'] = (0, 0)
diff --git a/torch/utils/trainer/plugins/plugin.py b/torch/utils/trainer/plugins/plugin.py
deleted file mode 100644
index e1ac251..0000000
--- a/torch/utils/trainer/plugins/plugin.py
+++ /dev/null
@@ -1,10 +0,0 @@
-
-class Plugin(object):
-
-    def __init__(self, interval=None):
-        if interval is None:
-            interval = []
-        self.trigger_interval = interval
-
-    def register(self, trainer):
-        raise NotImplementedError
diff --git a/torch/utils/trainer/plugins/progress.py b/torch/utils/trainer/plugins/progress.py
deleted file mode 100644
index 5b0dc2e..0000000
--- a/torch/utils/trainer/plugins/progress.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from .plugin import Plugin
-
-
-class ProgressMonitor(Plugin):
-    stat_name = 'progress'
-
-    def __init__(self):
-        super(ProgressMonitor, self).__init__([(1, 'iteration'), (1, 'epoch')])
-
-    def register(self, trainer):
-        self.trainer = trainer
-        stats = self.trainer.stats.setdefault(self.stat_name, {})
-        stats['samples_used'] = 0
-        stats['epoch_size'] = len(trainer.dataset)
-        stats['log_iter_fields'] = [
-            '{samples_used}/{epoch_size}',
-            '({percent:.2f}%)'
-        ]
-
-    def iteration(self, iteration, input, *args):
-        stats = self.trainer.stats.setdefault(self.stat_name, {})
-        stats['samples_used'] += 1
-        stats['percent'] = 100. * stats['samples_used'] / stats['epoch_size']
-
-    def epoch(self, *args):
-        stats = self.trainer.stats.setdefault(self.stat_name, {})
-        stats['samples_used'] = 0
-        stats['percent'] = 0
diff --git a/torch/utils/trainer/plugins/time.py b/torch/utils/trainer/plugins/time.py
deleted file mode 100644
index ffdc198..0000000
--- a/torch/utils/trainer/plugins/time.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from __future__ import absolute_import
-import time
-
-from .monitor import Monitor
-
-
-class TimeMonitor(Monitor):
-    stat_name = 'time'
-
-    def __init__(self, *args, **kwargs):
-        kwargs.setdefault('unit', 'ms')
-        kwargs.setdefault('precision', 0)
-        super(TimeMonitor, self).__init__(*args, **kwargs)
-        self.last_time = None
-
-    def _get_value(self, *args):
-        if self.last_time:
-            now = time.time()
-            duration = now - self.last_time
-            self.last_time = now
-            return duration * 1000
-        else:
-            self.last_time = time.time()
-            return 0
diff --git a/torch/utils/trainer/trainer.py b/torch/utils/trainer/trainer.py
deleted file mode 100644
index 3203a75..0000000
--- a/torch/utils/trainer/trainer.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import heapq
-
-
-class Trainer(object):
-
-    def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
-        self.model = model
-        self.criterion = criterion
-        self.optimizer = optimizer
-        self.dataset = dataset
-        self.iterations = 0
-        self.stats = {}
-        self.plugin_queues = {
-            'iteration': [],
-            'epoch': [],
-            'batch': [],
-            'update': [],
-        }
-
-    def register_plugin(self, plugin):
-        plugin.register(self)
-
-        intervals = plugin.trigger_interval
-        if not isinstance(intervals, list):
-            intervals = [intervals]
-        for duration, unit in intervals:
-            queue = self.plugin_queues[unit]
-            queue.append((duration, len(queue), plugin))
-
-    def call_plugins(self, queue_name, time, *args):
-        args = (time,) + args
-        queue = self.plugin_queues[queue_name]
-        if len(queue) == 0:
-            return
-        while queue[0][0] <= time:
-            plugin = queue[0][2]
-            getattr(plugin, queue_name)(*args)
-            for trigger in plugin.trigger_interval:
-                if trigger[1] == queue_name:
-                    interval = trigger[0]
-            new_item = (time + interval, queue[0][1], plugin)
-            heapq.heappushpop(queue, new_item)
-
-    def run(self, epochs=1):
-        for q in self.plugin_queues.values():
-            heapq.heapify(q)
-
-        for i in range(1, epochs + 1):
-            self.train()
-            self.call_plugins('epoch', i)
-
-    def train(self):
-        for i, data in enumerate(self.dataset, self.iterations + 1):
-            batch_input, batch_target = data
-            self.call_plugins('batch', i, batch_input, batch_target)
-            input_var = batch_input
-            target_var = batch_target
-
-            plugin_data = [None, None]
-
-            def closure():
-                batch_output = self.model(input_var)
-                loss = self.criterion(batch_output, target_var)
-                loss.backward()
-                if plugin_data[0] is None:
-                    plugin_data[0] = batch_output.data
-                    plugin_data[1] = loss.data
-                return loss
-
-            self.optimizer.zero_grad()
-            self.optimizer.step(closure)
-            self.call_plugins('iteration', i, batch_input, batch_target,
-                              *plugin_data)
-            self.call_plugins('update', i, self.model)
-
-        self.iterations += i