Add __syncthreads() between CUB reductions for elementwise linear gradient kernel

Summary: Thanks to ezyang, I now know that if a CUB TempStorage is reused for a second reduction, a __syncthreads() is needed between the two uses. Added this to the elementwise linear gradient kernel.
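
For context, a minimal standalone sketch of the pattern (kernel and variable names are hypothetical, not the actual Caffe2 kernel): two block-wide BlockReduce sums share one TempStorage, so a __syncthreads() must sit between them, otherwise the second reduction can start overwriting the shared storage while some threads are still reading the result of the first.

    #include <cub/cub.cuh>

    template <int BLOCK_SIZE>
    __global__ void TwoSumsKernel(const float* a, const float* b, float* out, int n) {
      typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;

      int i = blockIdx.x * BLOCK_SIZE + threadIdx.x;
      float va = (i < n) ? a[i] : 0.f;
      float vb = (i < n) ? b[i] : 0.f;

      // First block-wide reduction using the shared temp storage.
      float a_sum = BlockReduce(temp_storage).Sum(va);
      __syncthreads();  // required before reusing temp_storage
      // Second reduction reuses the same temp storage.
      float b_sum = BlockReduce(temp_storage).Sum(vb);

      if (threadIdx.x == 0) {
        out[2 * blockIdx.x] = a_sum;
        out[2 * blockIdx.x + 1] = b_sum;
      }
    }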

Reviewed By: wickedfoo, ezyang

Differential Revision: D4949250

fbshipit-source-id: fbbbd336a962a51be43784207105cadd391a8ef2
diff --git a/caffe2/operators/elementwise_linear_op.cu b/caffe2/operators/elementwise_linear_op.cu
index ebd9ead..18efb71 100644
--- a/caffe2/operators/elementwise_linear_op.cu
+++ b/caffe2/operators/elementwise_linear_op.cu
@@ -37,6 +37,7 @@
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   float g_a_sum_tot = BlockReduce(temp_storage).Sum(g_a_sum);
+  __syncthreads();
   float g_b_sum_tot = BlockReduce(temp_storage).Sum(g_b_sum);
 
   if (threadIdx.x == 0) {