Update resnet50_trainer example

Summary:
A few fixes in this commit: the epoch size is now rounded down to the
nearest integer multiple of the global batch size (batch per GPU *
GPUs per host * hosts per run). The num_shards and shard_id parameters
are now passed to CreateDB so that the processes actually train on
different subsets of the data. The LR step size is divided by the
number of hosts (shards) in the run, so it still corresponds to 30
epochs' worth of local iterations. Test accuracy is now computed only
at the end of each epoch (averaged over 100 test iterations) instead
of every 20 training iterations.
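For illustration, a minimal sketch (not the trainer code itself) of the
batch-size bookkeeping described above; the example values are assumptions,
not defaults taken from the script:

    # Assumed example values, not the script's defaults.
    batch_per_gpu = 32          # batch per GPU
    num_gpus = 8                # GPUs per host
    num_shards = 4              # hosts participating in the run
    requested_epoch_size = 1500000

    total_batch_size = batch_per_gpu * num_gpus        # per-host batch size
    global_batch_size = total_batch_size * num_shards  # batch size across all hosts

    # Round the epoch size down to the nearest multiple of the global batch
    # size, so every host runs the same whole number of iterations per epoch.
    epoch_iters = requested_epoch_size // global_batch_size
    epoch_size = epoch_iters * global_batch_size

    # The LR step size is divided by num_shards as well, which keeps the step
    # at 30 epochs' worth of local iterations regardless of the host count.
    stepsz = 30 * epoch_size // (total_batch_size * num_shards)

    print(epoch_size, epoch_iters, stepsz)  # 1499136 1464 43920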

Differential Revision: D4871505

fbshipit-source-id: d2703dc7cf1e0f76710d9d7c09cd362a42fe0598
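For reference, a standalone sketch of the per-epoch test pass: the accuracy
blobs of all GPUs are averaged over 100 test iterations, mirroring the
RunEpoch change in the diff below (the helper name run_test_accuracy is made
up for this illustration):

    import numpy as np
    from caffe2.python import workspace

    def run_test_accuracy(test_model, num_iters=100):
        # Average the per-GPU accuracy blobs over num_iters test iterations,
        # as the updated RunEpoch now does once per epoch.
        total, count = 0.0, 0
        for _ in range(num_iters):
            workspace.RunNet(test_model.net.Proto().name)
            for g in test_model._devices:
                total += np.asscalar(
                    workspace.FetchBlob("gpu_{}/accuracy".format(g)))
                count += 1
        return total / count if count > 0 else -1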
diff --git a/caffe2/python/examples/resnet50_trainer.py b/caffe2/python/examples/resnet50_trainer.py
index 374cf7c..fe2eb3f 100644
--- a/caffe2/python/examples/resnet50_trainer.py
+++ b/caffe2/python/examples/resnet50_trainer.py
@@ -87,6 +87,7 @@
     train_model,
     test_model,
     total_batch_size,
+    num_shards,
     expname,
     explog,
 ):
@@ -95,7 +96,7 @@
     TODO: add checkpointing here.
     '''
     # TODO: add loading from checkpoint
-    epoch_iters = int(args.epoch_size / total_batch_size)
+    epoch_iters = int(args.epoch_size / total_batch_size / num_shards)
     for i in range(epoch_iters):
         log.info("Start iteration {}/{} of epoch {}".format(
             i, epoch_iters, epoch,
@@ -107,52 +108,44 @@
         with timeout_guard.CompleteInTimeOrDie(timeout):
             workspace.RunNet(train_model.net.Proto().name)
 
-        num_images = (i + epoch * epoch_iters) * total_batch_size
-        record_freq = total_batch_size * 20
+    num_images = epoch * epoch_iters * total_batch_size
+    prefix = "gpu_{}".format(train_model._devices[0])
+    accuracy = workspace.FetchBlob(prefix + '/accuracy')
+    loss = workspace.FetchBlob(prefix + '/loss')
+    learning_rate = workspace.FetchBlob(prefix + '/LR')
+    test_accuracy = 0
+    if (test_model is not None):
+        # Run 100 iters of testing
+        ntests = 0
+        for _ in range(0, 100):
+            workspace.RunNet(test_model.net.Proto().name)
+            for g in test_model._devices:
+                test_accuracy += np.asscalar(workspace.FetchBlob(
+                    "gpu_{}".format(g) + '/accuracy'
+                ))
+                ntests += 1
+        test_accuracy /= ntests
+    else:
+        test_accuracy = (-1)
 
-        # Report progress, compute train and test accuracies.
-        if num_images % record_freq == 0 and i > 0:
-            prefix = "gpu_{}".format(train_model._devices[0])
-            accuracy = workspace.FetchBlob(prefix + '/accuracy')
-            loss = workspace.FetchBlob(prefix + '/loss')
-            learning_rate = workspace.FetchBlob(prefix + '/LR')
-
-            test_accuracy = 0
-            ntests = 0
-
-            if (test_model is not None):
-                # Run 5 iters of testing
-                for _ in range(0, 5):
-                    workspace.RunNet(test_model.net.Proto().name)
-                    for g in test_model._devices:
-                        test_accuracy += np.asscalar(workspace.FetchBlob(
-                            "gpu_{}".format(g) + '/accuracy'
-                        ))
-                        ntests += 1
-                test_accuracy /= ntests
-            else:
-                test_accuracy = (-1)
-
-            explog.log(
-                input_count=num_images,
-                batch_count=(i + epoch * epoch_iters),
-                additional_values={
-                    'accuracy': accuracy,
-                    'loss': loss,
-                    'learning_rate': learning_rate,
-                    'epoch': epoch,
-                    'test_accuracy': test_accuracy,
-                }
-            )
-            assert loss < 40, "Exploded gradients :("
+    explog.log(
+        input_count=num_images,
+        batch_count=(i + epoch * epoch_iters),
+        additional_values={
+            'accuracy': accuracy,
+            'loss': loss,
+            'learning_rate': learning_rate,
+            'epoch': epoch,
+            'test_accuracy': test_accuracy,
+        }
+    )
+    assert loss < 40, "Exploded gradients :("
 
     # TODO: add checkpointing
     return epoch + 1
 
 
 def Train(args):
-    total_batch_size = args.batch_size
-
     # Either use specified device list or generate one
     if args.gpus is not None:
         gpus = [int(x) for x in args.gpus.split(',')]
@@ -161,12 +154,20 @@
         gpus = range(args.num_gpus)
         num_gpus = args.num_gpus
 
+    log.info("Running on GPUs: {}".format(gpus))
+
+    # Verify valid batch size
+    total_batch_size = args.batch_size
     batch_per_device = total_batch_size // num_gpus
     assert \
         total_batch_size % num_gpus == 0, \
         "Number of GPUs must divide batch size"
 
-    log.info("Running on GPUs: {}".format(gpus))
+    # Round down epoch size to closest multiple of batch size across machines
+    global_batch_size = total_batch_size * args.num_shards
+    epoch_iters = int(args.epoch_size / global_batch_size)
+    args.epoch_size = epoch_iters * global_batch_size
+    log.info("Using epoch size: {}".format(args.epoch_size))
 
     # Create CNNModeLhelper object
     train_model = cnn.CNNModelHelper(
@@ -177,7 +178,9 @@
         ws_nbytes_limit=(args.cudnn_workspace_limit_mb * 1024 * 1024),
     )
 
-    if args.num_shards > 1:
+    num_shards = args.num_shards
+    shard_id = args.shard_id
+    if num_shards > 1:
         # Create rendezvous for distributed computation
         store_handler = "store_handler"
         if args.redis_host is not None:
@@ -200,8 +203,8 @@
             )
         rendezvous = dict(
             kv_handler=store_handler,
-            shard_id=args.shard_id,
-            num_shards=args.num_shards,
+            shard_id=shard_id,
+            num_shards=num_shards,
             engine="GLOO",
             exit_nets=None)
     else:
@@ -216,6 +219,7 @@
             num_input_channels=args.num_channels,
             num_labels=args.num_labels,
             label="label",
+            no_bias=True,
         )
         loss = model.Scale(loss, scale=loss_scale)
         model.Accuracy([softmax, "label"], "accuracy")
@@ -225,7 +229,7 @@
     def add_parameter_update_ops(model):
         model.AddWeightDecay(args.weight_decay)
         ITER = model.Iter("ITER")
-        stepsz = int(30 * args.epoch_size / total_batch_size)
+        stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)
         LR = model.net.LearningRate(
             [ITER],
             "LR",
@@ -241,6 +245,8 @@
         "reader",
         db=args.train_data,
         db_type=args.db_type,
+        num_shards=num_shards,
+        shard_id=shard_id,
     )
 
     def add_image_input(model):
@@ -317,6 +323,7 @@
             train_model,
             test_model,
             total_batch_size,
+            num_shards,
             expname,
             explog
         )
@@ -349,7 +356,7 @@
     parser.add_argument("--batch_size", type=int, default=32,
                         help="Batch size, total over all GPUs")
     parser.add_argument("--epoch_size", type=int, default=1500000,
-                        help="Number of images/epoch")
+                        help="Number of images/epoch, total over all machines")
     parser.add_argument("--num_epochs", type=int, default=1000,
                         help="Num epochs.")
     parser.add_argument("--base_learning_rate", type=float, default=0.1,