[pt][fbgemm] Turn on USE_FBGEMM on Windows env (#297)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33250

As the title says: FBGEMM recently added support for Windows, so this turns USE_FBGEMM on for the Windows build environment. The bulk of the change replaces GCC/Clang-only constructs that MSVC rejects (the 0x1p-10 hex-float literal, designated initializers, variable-length arrays, <x86intrin.h>, and OpenMP collapse / min-max reduction clauses), switches the DNNLOWP AVX2 sources to /arch:AVX2 under MSVC, adds a localtime_r shim for MSVC, exports the packed-weight structs with CAFFE2_API instead of FBGEMM_API, and lets cmake/Dependencies.cmake build FBGEMM under MSVC (as a shared library when BUILD_SHARED_LIBS is set).

ghstack-source-id: 97932881

Test Plan: CI

Reviewed By: jspark1105

Differential Revision: D19738268

fbshipit-source-id: e7f3c91f033018f6355edeaf6003bd2803119df4
diff --git a/.circleci/scripts/vs_install.ps1 b/.circleci/scripts/vs_install.ps1
index ed47ad2..6bbb1de 100644
--- a/.circleci/scripts/vs_install.ps1
+++ b/.circleci/scripts/vs_install.ps1
@@ -8,7 +8,6 @@
                                                      "--add Microsoft.VisualStudio.Component.VC.Redist.14.Latest",
                                                      "--add Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Core",
                                                      "--add Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
-                                                     "--add Microsoft.VisualStudio.Component.VC.Tools.14.11",
                                                      "--add Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Win81")
 
 curl.exe --retry 3 -kL $VS_DOWNLOAD_LINK --output vs_installer.exe
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index 6cd562b..69af5d6 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -301,7 +301,8 @@
   const unsigned short significand_bits = value & 0x3ff;
 
   const float sign = sign_bits ? -1 : 1;
-  const float significand = 1 + significand_bits * 0x1p-10;
+  const float significand =
+      1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10
   const float exponent = exponent_bits - 0xf;
 
   return sign * std::ldexp(significand, exponent);
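Note on this change (and the matching one in qlinear_prepack.cpp below): the hex-float literal 0x1p-10 is a C99 feature that only became standard C++ in C++17, and MSVC rejects it here, so the constant is spelled out in decimal. A minimal self-contained sketch of the same decode, mirroring the code above (normalized halfs only; half_bits_to_float is a hypothetical name, not the ATen helper):

    #include <cmath>
    #include <cstdint>

    // Decode an IEEE-754 binary16 bit pattern (normalized values only).
    static float half_bits_to_float(std::uint16_t value) {
      const std::uint16_t sign_bits = value >> 15;
      const std::uint16_t exponent_bits = (value >> 10) & 0x1f;
      const std::uint16_t significand_bits = value & 0x3ff;

      const float sign = sign_bits ? -1.0f : 1.0f;
      const float significand =
          1.0f + significand_bits * 0.0009765625f; // 0.0009765625f == 2^-10
      const int exponent = exponent_bits - 0xf;    // remove the bias of 15

      return sign * std::ldexp(significand, exponent);
    }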
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
index b96941d..70ff4dc 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -17,7 +17,7 @@
 // of the A rows. The column offsets are needed for the asymmetric quantization
 // (affine quantization) of input matrix.
 // Note that in JIT mode we can think of a way to fuse col_offsets with bias.
-struct FBGEMM_API PackedLinearWeight {
+struct CAFFE2_API PackedLinearWeight {
   std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
   c10::optional<at::Tensor> bias;
   std::vector<int32_t> col_offsets;
@@ -26,13 +26,13 @@
   c10::QScheme q_scheme;
 };
 
-struct FBGEMM_API PackedLinearWeightFp16 {
+struct CAFFE2_API PackedLinearWeightFp16 {
   std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w;
   c10::optional<at::Tensor> bias;
 };
 
 template <int kSpatialDim = 2>
-struct FBGEMM_API PackedConvWeight {
+struct CAFFE2_API PackedConvWeight {
   std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w;
   c10::optional<at::Tensor> bias;
   std::vector<int32_t> col_offsets;
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
index 58c64c1..f153a8a 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -257,7 +257,8 @@
     const unsigned short significand_bits = value & 0x3ff;
 
     const float sign = sign_bits ? -1 : 1;
-    const float significand = 1 + significand_bits * 0x1p-10;
+    const float significand = 1 +
+        significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10;
     const float exponent = exponent_bits - 0xf;
 
     return sign * std::ldexp(significand, exponent);
diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp
index f6b90dc..732b835 100644
--- a/aten/src/ATen/quantized/Quantizer.cpp
+++ b/aten/src/ATen/quantized/Quantizer.cpp
@@ -133,10 +133,9 @@
 
 template <typename T>
 inline float dequantize_val(double scale, int64_t zero_point, T value) {
-  fbgemm::TensorQuantizationParams qparams = {
-    .scale = static_cast<float>(scale),
-    .zero_point = static_cast<int32_t>(zero_point)
-  };
+  fbgemm::TensorQuantizationParams qparams;
+  qparams.scale = static_cast<float>(scale);
+  qparams.zero_point = static_cast<int32_t>(zero_point);
   return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
 }
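The removed braces use designated initializers, which GCC and Clang accept in C++ as an extension but which are only standard C++ from C++20; MSVC rejects them here, hence the member-by-member assignment. A minimal sketch of the portable pattern with a hypothetical Params struct (not the fbgemm type):

    struct Params {
      float scale;
      int zero_point;
    };

    Params make_params(double scale, long long zero_point) {
      Params p;                                     // default-construct, then
      p.scale = static_cast<float>(scale);          // assign each member
      p.zero_point = static_cast<int>(zero_point);  // explicitly
      return p;
    }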
 
diff --git a/caffe2/quantization/server/CMakeLists.txt b/caffe2/quantization/server/CMakeLists.txt
index e48fab3..24d3ae2 100644
--- a/caffe2/quantization/server/CMakeLists.txt
+++ b/caffe2/quantization/server/CMakeLists.txt
@@ -60,13 +60,20 @@
   #"${CMAKE_CURRENT_SOURCE_DIR}/sigmoid_test.cc")
   #"${CMAKE_CURRENT_SOURCE_DIR}/tanh_test.cc")
 
-if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
+if (CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
   add_library(caffe2_dnnlowp_avx2_ops OBJECT ${caffe2_dnnlowp_avx2_ops_SRCS})
   add_dependencies(caffe2_dnnlowp_avx2_ops fbgemm Caffe2_PROTO c10)
   target_include_directories(caffe2_dnnlowp_avx2_ops BEFORE
     PRIVATE $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>)
-  set_property(SOURCE ${caffe2_dnnlowp_avx2_ops_SRCS}
-    APPEND_STRING PROPERTY COMPILE_FLAGS " -mavx2 -mfma -mf16c -mxsave ")
+
+  if (MSVC)
+    set_property(SOURCE ${caffe2_dnnlowp_avx2_ops_SRCS}
+      APPEND_STRING PROPERTY COMPILE_FLAGS " /arch:AVX2 ")
+  else()
+    set_property(SOURCE ${caffe2_dnnlowp_avx2_ops_SRCS}
+      APPEND_STRING PROPERTY COMPILE_FLAGS " -mavx2 -mfma -mf16c -mxsave ")
+  endif()
+
   set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
     $<TARGET_OBJECTS:caffe2_dnnlowp_avx2_ops>)
 endif()
diff --git a/caffe2/quantization/server/activation_distribution_observer.h b/caffe2/quantization/server/activation_distribution_observer.h
index 3d9249c..02a6de6 100644
--- a/caffe2/quantization/server/activation_distribution_observer.h
+++ b/caffe2/quantization/server/activation_distribution_observer.h
@@ -221,4 +221,14 @@
       const std::string& qparams_output_file_name = "");
 };
 
+#ifdef _MSC_VER
+struct tm* localtime_r(time_t* _clock, struct tm* _result) {
+  struct tm* candidate_result = localtime(_clock);
+  if (candidate_result) {
+    *(_result) = *candidate_result;
+  }
+  return candidate_result;
+}
+#endif
+
 } // namespace caffe2
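One caveat with the shim above: it is a non-inline function definition in a header, so if the header ends up included from more than one translation unit it would presumably need inline (or static) linkage to avoid duplicate-symbol errors, and localtime() itself is not thread-safe. A hedged alternative sketch using MSVC's localtime_s, which writes into a caller-supplied buffer:

    #ifdef _MSC_VER
    #include <ctime>

    // Thread-safe localtime_r shim for MSVC; Microsoft's localtime_s takes
    // (tm*, const time_t*) and returns 0 on success.
    inline struct tm* localtime_r(const time_t* clock, struct tm* result) {
      return localtime_s(result, clock) == 0 ? result : nullptr;
    }
    #endif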
diff --git a/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc b/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc
index 2df5076..bfec3e1 100644
--- a/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc
+++ b/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc
@@ -430,8 +430,12 @@
       Y_int32_.resize(Y->numel());
 
 #ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
 #pragma omp parallel for collapse(2)
 #endif
+#endif
       for (int p = 0; p < num_outer_batches; ++p) {
         for (int i = 0; i < num_sub_batches; ++i) {
           int tid = dnnlowp_get_thread_num();
@@ -489,8 +493,12 @@
         A_pack_buf_.resize(A_pack_len_per_thread * dnnlowp_get_max_threads());
 
 #ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
 #pragma omp parallel for collapse(2)
 #endif
+#endif
         for (int p = 0; p < num_outer_batches; ++p) {
           for (int i = 0; i < num_sub_batches; ++i) {
             int tid = dnnlowp_get_thread_num();
@@ -544,8 +552,12 @@
             A_pack_buf_len_per_thread * dnnlowp_get_max_threads());
 
 #ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
 #pragma omp parallel for collapse(2)
 #endif
+#endif
         for (int p = 0; p < num_outer_batches; ++p) {
           for (int i = 0; i < num_sub_batches; ++i) {
             int tid = dnnlowp_get_thread_num();
@@ -610,8 +622,12 @@
     T* Y_quantized = GetQuantizedOutputData_();
     Y_int32_.resize(Y->numel());
 #ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
 #pragma omp parallel for collapse(2)
 #endif
+#endif
     for (int p = 0; p < num_outer_batches; ++p) {
       for (int i = 0; i < num_sub_batches; ++i) {
         // Y_q = (scale_A * scale_B) / scale_Y * Y_int32
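The repeated pattern in this file (and the reduction change in conv_dnnlowp_op.cc below) works around MSVC's OpenMP support, which is stuck at OpenMP 2.0 and has neither the collapse clause nor min/max reductions; on MSVC only the outer loop gets parallelized. A hedged alternative that keeps both loop dimensions parallel on every compiler is to flatten them by hand (illustrative only, not the operator's actual body):

    void flattened_batches(int num_outer_batches, int num_sub_batches) {
    #ifdef _OPENMP
    #pragma omp parallel for
    #endif
      for (int pi = 0; pi < num_outer_batches * num_sub_batches; ++pi) {
        const int p = pi / num_sub_batches;  // outer batch index
        const int i = pi % num_sub_batches;  // sub-batch index
        // ... per-(p, i) work would go here ...
        (void)p;
        (void)i;
      }
    }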
diff --git a/caffe2/quantization/server/conv_dnnlowp_op.cc b/caffe2/quantization/server/conv_dnnlowp_op.cc
index 3ebb2df..2821073 100644
--- a/caffe2/quantization/server/conv_dnnlowp_op.cc
+++ b/caffe2/quantization/server/conv_dnnlowp_op.cc
@@ -860,7 +860,7 @@
     int32_t Y_min = numeric_limits<int32_t>::max();
     int32_t Y_max = numeric_limits<int32_t>::min();
 
-#ifdef _OPENMP
+#if defined(_OPENMP) && !defined(_MSC_VER)
 #pragma omp parallel for reduction(min : Y_min), reduction(max : Y_max)
 #endif
     for (int i = 0; i < N * Y_HxW; ++i) {
diff --git a/caffe2/quantization/server/dnnlowp.h b/caffe2/quantization/server/dnnlowp.h
index 0d7414f..2f68d15 100644
--- a/caffe2/quantization/server/dnnlowp.h
+++ b/caffe2/quantization/server/dnnlowp.h
@@ -6,7 +6,7 @@
 #include <cstdint>
 #include <limits>
 
-#include <x86intrin.h>
+#include <immintrin.h>
 
 #include <fbgemm/QuantUtils.h>
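<x86intrin.h> is a GCC/Clang umbrella header that MSVC does not ship; <immintrin.h> is provided by all three compilers, so switching the include here (and in norm_minimization.cc and transpose.cc below) is sufficient. A conditional include would also work but is not needed:

    #if defined(_MSC_VER)
    #include <immintrin.h>  // MSVC: covers SSE/AVX/AVX2 intrinsics
    #else
    #include <x86intrin.h>  // GCC/Clang umbrella header
    #endif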
 
diff --git a/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc b/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc
index 450e641..457d702 100644
--- a/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc
+++ b/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc
@@ -3,6 +3,7 @@
 #include <array>
 #include <tuple>
 #include <type_traits>
+#include <vector>
 
 // #define DNNLOWP_MEASURE_TIME_BREAKDOWN
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -102,8 +103,8 @@
             out_qparams_.zero_point);
       } // omp parallel
     } else {
-      RequantizationParams in_requantization_params[InputSize()];
-      const T* input_data[InputSize()];
+      vector<RequantizationParams> in_requantization_params(InputSize());
+      vector<T*> input_data(InputSize());
       for (int i = 0; i < InputSize(); ++i) {
         float real_multiplier =
             in_qparams_[i].scale / intermediate_qparams_.scale;
@@ -136,9 +137,8 @@
         }
       }
     }
-  } // InputTensorCPU_(0).template IsType<T>()
-  else {
-    const float* input_data[InputSize()];
+  } else { // InputTensorCPU_(0).template IsType<T>()
+    vector<float*> input_data(InputSize());
     for (int i = 0; i < InputSize(); ++i) {
       input_data[i] = InputTensorCPU_(i).template data<float>();
     }
@@ -155,7 +155,7 @@
         int32_t acc = 0;
         for (int i = 0; i < InputSize(); ++i) {
           acc += fbgemm::Quantize<int32_t>(
-              ((const float*)input_data[i])[j],
+              input_data[i][j],
               intermediate_qparams_.zero_point,
               intermediate_qparams_.scale,
               qfactory_->GetEltwiseQuantizePrecision());
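The arrays being replaced in this file are variable-length arrays (sized by InputSize() at run time), a GCC/Clang extension that MSVC does not implement, so std::vector is used instead. A tiny sketch of the difference (illustrative, not caffe2 code):

    #include <vector>

    void scratch_buffers(int n) {
      // int vla[n];            // VLA: a GCC/Clang extension, rejected by MSVC
      std::vector<int> buf(n);  // portable replacement (heap-allocated)
      (void)buf;
    }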
diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc
index 085a41c..f1dc732 100644
--- a/caffe2/quantization/server/fbgemm_pack_op.cc
+++ b/caffe2/quantization/server/fbgemm_pack_op.cc
@@ -736,7 +736,8 @@
   std::vector<int32_t> offsets;
   for (const auto v : dnntensor.qparams) {
     scales.push_back(v.scale);
-    offsets.push_back(reinterpret_cast<int32_t>(v.zero_point));
+    int32_t cur_offset = v.zero_point;
+    offsets.push_back(cur_offset);
   }
   all_scales->push_back(scales);
   all_offsets->push_back(offsets);
@@ -784,7 +785,8 @@
   std::vector<int32_t> offsets;
   for (const auto v : dnntensor.qparams) {
     scales.push_back(v.scale);
-    offsets.push_back(reinterpret_cast<int32_t>(v.zero_point));
+    int32_t cur_offset = v.zero_point;
+    offsets.push_back(cur_offset);
   }
   all_scales->push_back(scales);
   all_offsets->push_back(offsets);
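The removed lines use reinterpret_cast where only an ordinary integer conversion is intended; reinterpret_cast is not the tool for value conversions between integer types, and the MSVC build apparently choked on it, so a plain assignment is used instead. The equivalent explicit spelling would be static_cast (sketch with a hypothetical helper, not the operator's code):

    #include <cstdint>
    #include <vector>

    void append_offset(std::vector<std::int32_t>* offsets,
                       std::int32_t zero_point) {
      offsets->push_back(static_cast<std::int32_t>(zero_point));
    }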
diff --git a/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc b/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc
index c66bb97..7965e57 100644
--- a/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc
+++ b/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc
@@ -6,6 +6,7 @@
 
 // NOTE: clang-format wants to use a different formatting but the
 // current formatting should be easier to read.
+// clang-format off
 alignas(64) const int ld_st_masks[8][8] = {
   {  0,  0,  0,  0,  0,  0,  0,  0, },
   { -1,  0,  0,  0,  0,  0,  0,  0, },
@@ -16,6 +17,7 @@
   { -1, -1, -1, -1, -1, -1,  0,  0, },
   { -1, -1, -1, -1, -1, -1, -1,  0, },
 };
+// clang-format on
 
 } // anonymous namespace
 
diff --git a/caffe2/quantization/server/norm_minimization.cc b/caffe2/quantization/server/norm_minimization.cc
index ac53daa..94e655e 100644
--- a/caffe2/quantization/server/norm_minimization.cc
+++ b/caffe2/quantization/server/norm_minimization.cc
@@ -6,7 +6,7 @@
 #include <cmath>
 #include <limits>
 
-#include <x86intrin.h>
+#include <immintrin.h>
 
 using namespace std;
 
@@ -160,8 +160,8 @@
     start_bin = next_start_bin;
     end_bin = next_end_bin;
   }
-  VLOG(2) << "best quantization range " << start_bin << "," << end_bin + 1 << ","
-          << norm_min;
+  VLOG(2) << "best quantization range " << start_bin << "," << end_bin + 1
+          << "," << norm_min;
 
   double selected_sum = 0;
   for (int i = start_bin; i < end_bin + 1; ++i) {
diff --git a/caffe2/quantization/server/transpose.cc b/caffe2/quantization/server/transpose.cc
index ac53707..4cf3e36 100644
--- a/caffe2/quantization/server/transpose.cc
+++ b/caffe2/quantization/server/transpose.cc
@@ -1,6 +1,6 @@
 #include "transpose.h"
 
-#include <x86intrin.h>
+#include <immintrin.h>
 
 namespace fbgemm {
 
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 03f7b0e..7618a3c 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -523,17 +523,14 @@
       "Turn this warning off by USE_FBGEMM=OFF.")
     set(USE_FBGEMM OFF)
   endif()
-  if(MSVC)
-    message(WARNING
-      "FBGEMM is currently not supported on windows with MSVC. "
-      "Not compiling with FBGEMM. "
-      "Turn this warning off by USE_FBGEMM=OFF.")
-    set(USE_FBGEMM OFF)
-  endif()
   if(USE_FBGEMM AND NOT TARGET fbgemm)
     set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "")
     set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "")
-    set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "")
+    if(MSVC AND BUILD_SHARED_LIBS)
+      set(FBGEMM_LIBRARY_TYPE "shared" CACHE STRING "")
+    else()
+      set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "")
+    endif()
     add_subdirectory("${FBGEMM_SOURCE_DIR}")
     set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON)
     set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON)