make functional layer return scalar if only one output

Summary: This diff makes functional layer return scalar if only one output. This diff also corrects all other corresponding implementations.

Reviewed By: kittipatv

Differential Revision: D5386853

fbshipit-source-id: 1f00582f6ec23384b2a6db94e19952836755ef42
diff --git a/caffe2/python/layers/functional.py b/caffe2/python/layers/functional.py
index 71c99e7..ea4270b 100644
--- a/caffe2/python/layers/functional.py
+++ b/caffe2/python/layers/functional.py
@@ -1,4 +1,4 @@
-## @package functional
+# @package functional
 # Module caffe2.python.layers.functional
 from __future__ import absolute_import
 from __future__ import division
@@ -11,6 +11,7 @@
 )
 import caffe2.proto.caffe2_pb2 as caffe2_pb2
 import numpy as np
+import six
 import logging
 
 logger = logging.getLogger(__name__)
@@ -28,10 +29,15 @@
         super(Functional, self).__init__(model, name, input_record, **kwargs)
         self._function = function
         self._kwargs = kwargs
+        return_struct = (
+            isinstance(output_names_or_num, list) or
+            (isinstance(output_names_or_num, six.integer_types) and
+             output_names_or_num != 1)
+        )
 
         with scope.NameScope(self.name):
             if isinstance(output_names_or_num, int):
-                self.output_schema = schema.NewRecord(
+                struct_output_schema = schema.NewRecord(
                     model.net, schema.RawTuple(output_names_or_num))
             elif isinstance(output_names_or_num, schema.Field):
                 self.output_schema = output_names_or_num.clone(keep_blobs=True)
@@ -40,10 +46,17 @@
                 if not isinstance(output_names_or_num, list):
                     output_names_or_num = [output_names_or_num]
                 out_tuple = [(out, np.void) for out in output_names_or_num]
-                self.output_schema = schema.NewRecord(
+                struct_output_schema = schema.NewRecord(
                     model.net, schema.Struct(*out_tuple))
 
-        num_outputs = len(self.output_schema.field_blobs())
+        num_outputs = len(struct_output_schema.field_blobs())
+
+        # functional layer returns Struct if more than one outputs or output is
+        # a list, otherwise Scalar
+        if return_struct:
+            self.output_schema = struct_output_schema
+        else:
+            self.output_schema = struct_output_schema[0]
 
         # If output_dtypes is provided, use it for output schema. Otherwise
         # the shape and type will be inferred.
@@ -65,7 +78,9 @@
             function(type_net, self.input_record, self.output_schema, **kwargs)
             (shapes, types) = workspace.InferShapesAndTypes([type_net], {})
             for i in range(num_outputs):
-                blob = self.output_schema[i]()
+                scalar_schema = (self.output_schema[i] if return_struct
+                                 else self.output_schema)
+                blob = scalar_schema()
                 if blob not in types or blob not in shapes:
                     had_issues = True
                     continue
@@ -93,7 +108,7 @@
                     dtype = (np.int64, shape)
 
                 if dtype is not None:
-                    self.output_schema[i].set_type(dtype)
+                    scalar_schema.set_type(dtype)
         except TypeError as ex:
             had_issues = True
             logger.warning(str(ex))
diff --git a/caffe2/python/layers_test.py b/caffe2/python/layers_test.py
index a5b07d0..1614981 100644
--- a/caffe2/python/layers_test.py
+++ b/caffe2/python/layers_test.py
@@ -490,7 +490,7 @@
         schema.FeedRecord(float_features, [float_array])
 
         with Tags(Tags.EXCLUDE_FROM_PREDICTION):
-            log_float_features, = self.model.Log(float_features, 1)
+            log_float_features = self.model.Log(float_features, 1)
         joined = self.model.SelectRecordByContext(
             schema.Struct(
                 (InstantiationContext.PREDICTION, float_features),
@@ -529,15 +529,15 @@
             mean = net.ReduceFrontMean(in_record(), 1)
             net.Sub(
                 [in_record(), mean],
-                out_record[0](),
+                out_record(),
                 broadcast=1)
         normalized = self.model.Functional(
             self.model.input_feature_schema.float_features, 1,
             normalize, name="normalizer")
 
         # Attach metadata to one of the outputs and use it in FC
-        normalized[0].set_type((np.float32, 32))
-        self.model.output_schema = self.model.FC(normalized[0], 2)
+        normalized.set_type((np.float32, 32))
+        self.model.output_schema = self.model.FC(normalized, 2)
 
         predict_net = layer_model_instantiator.generate_predict_net(
             self.model)
@@ -557,11 +557,11 @@
             self.model.input_feature_schema.float_features, 1)
         normalized = self.model.Sub(
             schema.Tuple(
-                self.model.input_feature_schema.float_features, mean[0]),
+                self.model.input_feature_schema.float_features, mean),
             1, broadcast=1)
         # Attach metadata to one of the outputs and use it in FC
-        normalized[0].set_type((np.float32, (32,)))
-        self.model.output_schema = self.model.FC(normalized[0], 2)
+        normalized.set_type((np.float32, (32,)))
+        self.model.output_schema = self.model.FC(normalized, 2)
 
         predict_net = layer_model_instantiator.generate_predict_net(
             self.model)
@@ -580,10 +580,9 @@
         softsign = self.model.Softsign(
             schema.Tuple(self.model.input_feature_schema.float_features),
             1)
-        assert len(softsign.field_types()) == 1
-        assert softsign.field_types()[0].base == np.float32
-        assert softsign.field_types()[0].shape == (32,)
-        self.model.output_schema = self.model.FC(softsign[0], 2)
+        assert softsign.field_type().base == np.float32
+        assert softsign.field_type().shape == (32,)
+        self.model.output_schema = self.model.FC(softsign, 2)
 
         predict_net = layer_model_instantiator.generate_predict_net(
             self.model)