Expunge torch.utils.trainer.* (#12487)
Differential Revision: D10273602
Pulled By: SsnL
fbshipit-source-id: 630c1f8ee0e366f7092d4f93dbe1efa96fc860e0
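
For code that still depends on the removed Trainer, the same behaviour is a plain training
loop over the standard nn/optim APIs. A minimal illustrative sketch follows (not part of this
change; `model`, `criterion`, `optimizer`, and `dataset` are hypothetical stand-ins for
whatever was previously passed to Trainer(...)):

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the objects previously handed to Trainer(...).
    model = nn.Linear(10, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataset = [(torch.randn(2, 10), torch.randint(0, 2, (2,))) for _ in range(10)]

    num_epochs = 3
    for epoch in range(num_epochs):
        for batch_input, batch_target in dataset:
            optimizer.zero_grad()        # Trainer called zero_grad() once per iteration
            output = model(batch_input)
            loss = criterion(output, batch_target)
            loss.backward()
            optimizer.step()             # Trainer passed a closure to step(); a plain step() suffices here

Per-iteration or per-epoch bookkeeping that plugins such as LossMonitor or ProgressMonitor
provided can be inlined at the corresponding points in the loop.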
diff --git a/test/test_utils.py b/test/test_utils.py
index dff6102..cca61b5 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -14,9 +14,6 @@
import torch.cuda
import warnings
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
-from torch.utils.trainer import Trainer
-from torch.utils.trainer.plugins import *
-from torch.utils.trainer.plugins.plugin import Plugin
from torch.autograd._functions.utils import prepare_onnx_paddings
from torch.autograd._functions.utils import check_onnx_broadcast
from common import IS_WINDOWS, IS_PPC, skipIfRocm
@@ -26,85 +23,6 @@
from common import TestCase, run_tests, download_file
-class SimplePlugin(Plugin):
-
- def __init__(self, interval):
- super(SimplePlugin, self).__init__(interval)
- self.trainer = None
- self.num_iteration = 0
- self.num_epoch = 0
- self.num_batch = 0
- self.num_update = 0
-
- def register(self, trainer):
- self.trainer = trainer
-
- def iteration(self, *args):
- self.iteration_args = args
- self.num_iteration += 1
-
- def epoch(self, *args):
- self.epoch_args = args
- self.num_epoch += 1
-
- def batch(self, *args):
- self.batch_args = args
- self.num_batch += 1
-
- def update(self, *args):
- self.update_args = args
- self.num_update += 1
-
-
-class ModelMock(object):
-
- def __init__(self):
- self.num_calls = 0
- self.output = torch.ones(1, 1, requires_grad=True)
-
- def __call__(self, i):
- self.num_calls += 1
- return self.output * 2
-
-
-class CriterionMock(object):
-
- def __init__(self):
- self.num_calls = 0
-
- def __call__(self, out, target):
- self.num_calls += 1
- return out
-
-
-class OptimizerMock(object):
- max_evals = 5
- min_evals = 1
-
- def __init__(self):
- self.num_steps = 0
- self.num_evals = 0
-
- def step(self, closure):
- for i in range(random.randint(self.min_evals, self.max_evals)):
- loss = closure()
- self.num_evals += 1
- self.num_steps += 1
-
- def zero_grad(self):
- pass
-
-
-class DatasetMock(object):
-
- def __iter__(self):
- for i in range(10):
- yield torch.randn(2, 10), torch.randperm(10)[:2]
-
- def __len__(self):
- return 10
-
-
class RandomDatasetMock(object):
def __getitem__(self, index):
@@ -279,84 +197,6 @@
self.assertEqual(len(list(dataiter)), 1)
-class TestTrainer(TestCase):
-
- intervals = [
- [(1, 'iteration')],
- [(1, 'epoch')],
- [(1, 'batch')],
- [(1, 'update')],
- [(5, 'iteration')],
- [(5, 'epoch')],
- [(5, 'batch')],
- [(5, 'update')],
- [(1, 'iteration'), (1, 'epoch')],
- [(5, 'update'), (1, 'iteration')],
- [(2, 'epoch'), (1, 'batch')],
- ]
-
- def setUp(self):
- self.optimizer = OptimizerMock()
- self.trainer = Trainer(ModelMock(), CriterionMock(),
- self.optimizer, DatasetMock())
- self.num_epochs = 3
- self.dataset_size = len(self.trainer.dataset)
- self.num_iters = self.num_epochs * self.dataset_size
-
- def test_register_plugin(self):
- for interval in self.intervals:
- simple_plugin = SimplePlugin(interval)
- self.trainer.register_plugin(simple_plugin)
- self.assertEqual(simple_plugin.trainer, self.trainer)
-
- def test_optimizer_step(self):
- self.trainer.run(epochs=1)
- self.assertEqual(self.trainer.optimizer.num_steps, 10)
-
- def test_plugin_interval(self):
- for interval in self.intervals:
- self.setUp()
- simple_plugin = SimplePlugin(interval)
- self.trainer.register_plugin(simple_plugin)
- self.trainer.run(epochs=self.num_epochs)
- units = {
- ('iteration', self.num_iters),
- ('epoch', self.num_epochs),
- ('batch', self.num_iters),
- ('update', self.num_iters)
- }
- for unit, num_triggers in units:
- call_every = None
- for i, i_unit in interval:
- if i_unit == unit:
- call_every = i
- break
- if call_every:
- expected_num_calls = math.floor(num_triggers / call_every)
- else:
- expected_num_calls = 0
- num_calls = getattr(simple_plugin, 'num_' + unit)
- self.assertEqual(num_calls, expected_num_calls, 0)
-
- def test_model_called(self):
- self.trainer.run(epochs=self.num_epochs)
- num_model_calls = self.trainer.model.num_calls
- num_crit_calls = self.trainer.criterion.num_calls
- self.assertEqual(num_model_calls, num_crit_calls)
- for num_calls in [num_model_calls, num_crit_calls]:
- lower_bound = OptimizerMock.min_evals * self.num_iters
- upper_bound = OptimizerMock.max_evals * self.num_iters
- self.assertEqual(num_calls, self.trainer.optimizer.num_evals)
- self.assertLessEqual(lower_bound, num_calls)
- self.assertLessEqual(num_calls, upper_bound)
-
- def test_model_gradient(self):
- self.trainer.run(epochs=self.num_epochs)
- output_var = self.trainer.model.output
- expected_grad = torch.ones(1, 1) * 2 * self.optimizer.num_evals
- self.assertEqual(output_var.grad.data, expected_grad)
-
-
test_dir = os.path.abspath(os.path.dirname(str(__file__)))
diff --git a/torch/utils/trainer/__init__.py b/torch/utils/trainer/__init__.py
deleted file mode 100644
index 29340cf..0000000
--- a/torch/utils/trainer/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-
-from .trainer import Trainer
diff --git a/torch/utils/trainer/plugins/__init__.py b/torch/utils/trainer/plugins/__init__.py
deleted file mode 100644
index e8d10f4..0000000
--- a/torch/utils/trainer/plugins/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .progress import ProgressMonitor
-from .accuracy import AccuracyMonitor
-from .time import TimeMonitor
-from .loss import LossMonitor
-from .logger import Logger
diff --git a/torch/utils/trainer/plugins/accuracy.py b/torch/utils/trainer/plugins/accuracy.py
deleted file mode 100644
index f6f393c..0000000
--- a/torch/utils/trainer/plugins/accuracy.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from .monitor import Monitor
-
-
-class AccuracyMonitor(Monitor):
- stat_name = 'accuracy'
-
- def __init__(self, *args, **kwargs):
- kwargs.setdefault('unit', '%')
- kwargs.setdefault('precision', 2)
- super(AccuracyMonitor, self).__init__(*args, **kwargs)
-
- def _get_value(self, iteration, input, target, output, loss):
- batch_size = input.size(0)
- predictions = output.max(1)[1].type_as(target)
- correct = predictions.eq(target)
- if not hasattr(correct, 'sum'):
- correct = correct.cpu()
- correct = correct.sum()
- return 100. * correct / batch_size
diff --git a/torch/utils/trainer/plugins/logger.py b/torch/utils/trainer/plugins/logger.py
deleted file mode 100644
index 9bc2dfc..0000000
--- a/torch/utils/trainer/plugins/logger.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from collections import defaultdict
-from .plugin import Plugin
-
-
-class Logger(Plugin):
- alignment = 4
- separator = '#' * 80
-
- def __init__(self, fields, interval=None):
- if interval is None:
- interval = [(1, 'iteration'), (1, 'epoch')]
- super(Logger, self).__init__(interval)
- self.field_widths = defaultdict(lambda: defaultdict(int))
- self.fields = list(map(lambda f: f.split('.'), fields))
-
- def _join_results(self, results):
- joined_out = map(lambda i: (i[0], ' '.join(i[1])), results)
- joined_fields = map(lambda i: '{}: {}'.format(i[0], i[1]), joined_out)
- return '\t'.join(joined_fields)
-
- def log(self, msg):
- print(msg)
-
- def register(self, trainer):
- self.trainer = trainer
-
- def gather_stats(self):
- result = {}
- return result
-
- def _align_output(self, field_idx, output):
- for output_idx, o in enumerate(output):
- if len(o) < self.field_widths[field_idx][output_idx]:
- num_spaces = self.field_widths[field_idx][output_idx] - len(o)
- output[output_idx] += ' ' * num_spaces
- else:
- self.field_widths[field_idx][output_idx] = len(o)
-
- def _gather_outputs(self, field, log_fields, stat_parent, stat, require_dict=False):
- output = []
- name = ''
- if isinstance(stat, dict):
- log_fields = stat.get(log_fields, [])
- name = stat.get('log_name', '.'.join(field))
- for f in log_fields:
- output.append(f.format(**stat))
- elif not require_dict:
- name = '.'.join(field)
- number_format = stat_parent.get('log_format', '')
- unit = stat_parent.get('log_unit', '')
- fmt = '{' + number_format + '}' + unit
- output.append(fmt.format(stat))
- return name, output
-
- def _log_all(self, log_fields, prefix=None, suffix=None, require_dict=False):
- results = []
- for field_idx, field in enumerate(self.fields):
- parent, stat = None, self.trainer.stats
- for f in field:
- parent, stat = stat, stat[f]
- name, output = self._gather_outputs(field, log_fields,
- parent, stat, require_dict)
- if not output:
- continue
- self._align_output(field_idx, output)
- results.append((name, output))
- if not results:
- return
- output = self._join_results(results)
- if prefix is not None:
- self.log(prefix)
- self.log(output)
- if suffix is not None:
- self.log(suffix)
-
- def iteration(self, *args):
- self._log_all('log_iter_fields')
-
- def epoch(self, epoch_idx):
- self._log_all('log_epoch_fields',
- prefix=self.separator + '\nEpoch summary:',
- suffix=self.separator,
- require_dict=True)
diff --git a/torch/utils/trainer/plugins/loss.py b/torch/utils/trainer/plugins/loss.py
deleted file mode 100644
index 1bd93f2..0000000
--- a/torch/utils/trainer/plugins/loss.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .monitor import Monitor
-
-
-class LossMonitor(Monitor):
- stat_name = 'loss'
-
- def _get_value(self, iteration, input, target, output, loss):
- return loss.item()
diff --git a/torch/utils/trainer/plugins/monitor.py b/torch/utils/trainer/plugins/monitor.py
deleted file mode 100644
index b1e1d9f..0000000
--- a/torch/utils/trainer/plugins/monitor.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from .plugin import Plugin
-
-
-class Monitor(Plugin):
-
- def __init__(self, running_average=True, epoch_average=True, smoothing=0.7,
- precision=None, number_format=None, unit=''):
- if precision is None:
- precision = 4
- if number_format is None:
- number_format = '.{}f'.format(precision)
- number_format = ':' + number_format
- super(Monitor, self).__init__([(1, 'iteration'), (1, 'epoch')])
-
- self.smoothing = smoothing
- self.with_running_average = running_average
- self.with_epoch_average = epoch_average
-
- self.log_format = number_format
- self.log_unit = unit
- self.log_epoch_fields = None
- self.log_iter_fields = ['{last' + number_format + '}' + unit]
- if self.with_running_average:
- self.log_iter_fields += [' ({running_avg' + number_format + '}' + unit + ')']
- if self.with_epoch_average:
- self.log_epoch_fields = ['{epoch_mean' + number_format + '}' + unit]
-
- def register(self, trainer):
- self.trainer = trainer
- stats = self.trainer.stats.setdefault(self.stat_name, {})
- stats['log_format'] = self.log_format
- stats['log_unit'] = self.log_unit
- stats['log_iter_fields'] = self.log_iter_fields
- if self.with_epoch_average:
- stats['log_epoch_fields'] = self.log_epoch_fields
- if self.with_epoch_average:
- stats['epoch_stats'] = (0, 0)
-
- def iteration(self, *args):
- stats = self.trainer.stats.setdefault(self.stat_name, {})
- stats['last'] = self._get_value(*args)
-
- if self.with_epoch_average:
- stats['epoch_stats'] = tuple(sum(t) for t in
- zip(stats['epoch_stats'], (stats['last'], 1)))
-
- if self.with_running_average:
- previous_avg = stats.get('running_avg', 0)
- stats['running_avg'] = previous_avg * self.smoothing + \
- stats['last'] * (1 - self.smoothing)
-
- def epoch(self, idx):
- stats = self.trainer.stats.setdefault(self.stat_name, {})
- if self.with_epoch_average:
- epoch_stats = stats['epoch_stats']
- stats['epoch_mean'] = epoch_stats[0] / epoch_stats[1]
- stats['epoch_stats'] = (0, 0)
diff --git a/torch/utils/trainer/plugins/plugin.py b/torch/utils/trainer/plugins/plugin.py
deleted file mode 100644
index e1ac251..0000000
--- a/torch/utils/trainer/plugins/plugin.py
+++ /dev/null
@@ -1,10 +0,0 @@
-
-class Plugin(object):
-
- def __init__(self, interval=None):
- if interval is None:
- interval = []
- self.trigger_interval = interval
-
- def register(self, trainer):
- raise NotImplementedError
diff --git a/torch/utils/trainer/plugins/progress.py b/torch/utils/trainer/plugins/progress.py
deleted file mode 100644
index 5b0dc2e..0000000
--- a/torch/utils/trainer/plugins/progress.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from .plugin import Plugin
-
-
-class ProgressMonitor(Plugin):
- stat_name = 'progress'
-
- def __init__(self):
- super(ProgressMonitor, self).__init__([(1, 'iteration'), (1, 'epoch')])
-
- def register(self, trainer):
- self.trainer = trainer
- stats = self.trainer.stats.setdefault(self.stat_name, {})
- stats['samples_used'] = 0
- stats['epoch_size'] = len(trainer.dataset)
- stats['log_iter_fields'] = [
- '{samples_used}/{epoch_size}',
- '({percent:.2f}%)'
- ]
-
- def iteration(self, iteration, input, *args):
- stats = self.trainer.stats.setdefault(self.stat_name, {})
- stats['samples_used'] += 1
- stats['percent'] = 100. * stats['samples_used'] / stats['epoch_size']
-
- def epoch(self, *args):
- stats = self.trainer.stats.setdefault(self.stat_name, {})
- stats['samples_used'] = 0
- stats['percent'] = 0
diff --git a/torch/utils/trainer/plugins/time.py b/torch/utils/trainer/plugins/time.py
deleted file mode 100644
index ffdc198..0000000
--- a/torch/utils/trainer/plugins/time.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from __future__ import absolute_import
-import time
-
-from .monitor import Monitor
-
-
-class TimeMonitor(Monitor):
- stat_name = 'time'
-
- def __init__(self, *args, **kwargs):
- kwargs.setdefault('unit', 'ms')
- kwargs.setdefault('precision', 0)
- super(TimeMonitor, self).__init__(*args, **kwargs)
- self.last_time = None
-
- def _get_value(self, *args):
- if self.last_time:
- now = time.time()
- duration = now - self.last_time
- self.last_time = now
- return duration * 1000
- else:
- self.last_time = time.time()
- return 0
diff --git a/torch/utils/trainer/trainer.py b/torch/utils/trainer/trainer.py
deleted file mode 100644
index 3203a75..0000000
--- a/torch/utils/trainer/trainer.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import heapq
-
-
-class Trainer(object):
-
- def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
- self.model = model
- self.criterion = criterion
- self.optimizer = optimizer
- self.dataset = dataset
- self.iterations = 0
- self.stats = {}
- self.plugin_queues = {
- 'iteration': [],
- 'epoch': [],
- 'batch': [],
- 'update': [],
- }
-
- def register_plugin(self, plugin):
- plugin.register(self)
-
- intervals = plugin.trigger_interval
- if not isinstance(intervals, list):
- intervals = [intervals]
- for duration, unit in intervals:
- queue = self.plugin_queues[unit]
- queue.append((duration, len(queue), plugin))
-
- def call_plugins(self, queue_name, time, *args):
- args = (time,) + args
- queue = self.plugin_queues[queue_name]
- if len(queue) == 0:
- return
- while queue[0][0] <= time:
- plugin = queue[0][2]
- getattr(plugin, queue_name)(*args)
- for trigger in plugin.trigger_interval:
- if trigger[1] == queue_name:
- interval = trigger[0]
- new_item = (time + interval, queue[0][1], plugin)
- heapq.heappushpop(queue, new_item)
-
- def run(self, epochs=1):
- for q in self.plugin_queues.values():
- heapq.heapify(q)
-
- for i in range(1, epochs + 1):
- self.train()
- self.call_plugins('epoch', i)
-
- def train(self):
- for i, data in enumerate(self.dataset, self.iterations + 1):
- batch_input, batch_target = data
- self.call_plugins('batch', i, batch_input, batch_target)
- input_var = batch_input
- target_var = batch_target
-
- plugin_data = [None, None]
-
- def closure():
- batch_output = self.model(input_var)
- loss = self.criterion(batch_output, target_var)
- loss.backward()
- if plugin_data[0] is None:
- plugin_data[0] = batch_output.data
- plugin_data[1] = loss.data
- return loss
-
- self.optimizer.zero_grad()
- self.optimizer.step(closure)
- self.call_plugins('iteration', i, batch_input, batch_target,
- *plugin_data)
- self.call_plugins('update', i, self.model)
-
- self.iterations += i