[pt][fbgemm] Turn on USE_FBGEMM on Windows env (#297)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/297
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33250
As the title says: FBGEMM recently added support for Windows.
ghstack-source-id: 97932881
Test Plan: CI
Reviewed By: jspark1105
Differential Revision: D19738268
fbshipit-source-id: e7f3c91f033018f6355edeaf6003bd2803119df4
diff --git a/.circleci/scripts/vs_install.ps1 b/.circleci/scripts/vs_install.ps1
index ed47ad2..6bbb1de 100644
--- a/.circleci/scripts/vs_install.ps1
+++ b/.circleci/scripts/vs_install.ps1
@@ -8,7 +8,6 @@
"--add Microsoft.VisualStudio.Component.VC.Redist.14.Latest",
"--add Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Core",
"--add Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
- "--add Microsoft.VisualStudio.Component.VC.Tools.14.11",
"--add Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Win81")
curl.exe --retry 3 -kL $VS_DOWNLOAD_LINK --output vs_installer.exe
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index 6cd562b..69af5d6 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -301,7 +301,8 @@
const unsigned short significand_bits = value & 0x3ff;
const float sign = sign_bits ? -1 : 1;
- const float significand = 1 + significand_bits * 0x1p-10;
+ const float significand =
+ 1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10
const float exponent = exponent_bits - 0xf;
return sign * std::ldexp(significand, exponent);
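The hunk above swaps the hexadecimal floating-point literal 0x1p-10 for its decimal equivalent because MSVC historically rejects hex float literals in C++ sources (they only became standard in C++17). For context, a self-contained sketch of the same binary16-to-float decoding; this illustrates the hunk's logic and is not the exact PyTorch helper, and it ignores subnormals, infinities, and NaNs:

    #include <cmath>
    #include <cstdint>

    float half_bits_to_float(uint16_t value) {
      const uint16_t sign_bits = value >> 15;
      const uint16_t exponent_bits = (value >> 10) & 0x1f;
      const uint16_t significand_bits = value & 0x3ff;
      const float sign = sign_bits ? -1.0f : 1.0f;
      // 2^-10 spelled as a decimal literal; MSVC rejects 0x1p-10.
      const float significand = 1.0f + significand_bits * 0.0009765625f;
      const int exponent = exponent_bits - 0xf; // binary16 exponent bias is 15
      return sign * std::ldexp(significand, exponent);
    }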
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
index b96941d..70ff4dc 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -17,7 +17,7 @@
// of the A rows. The column offsets are needed for the asymmetric quantization
// (affine quantization) of input matrix.
// Note that in JIT mode we can think of a way to fuse col_offsets with bias.
-struct FBGEMM_API PackedLinearWeight {
+struct CAFFE2_API PackedLinearWeight {
std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
c10::optional<at::Tensor> bias;
std::vector<int32_t> col_offsets;
@@ -26,13 +26,13 @@
c10::QScheme q_scheme;
};
-struct FBGEMM_API PackedLinearWeightFp16 {
+struct CAFFE2_API PackedLinearWeightFp16 {
std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w;
c10::optional<at::Tensor> bias;
};
template <int kSpatialDim = 2>
-struct FBGEMM_API PackedConvWeight {
+struct CAFFE2_API PackedConvWeight {
std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w;
c10::optional<at::Tensor> bias;
std::vector<int32_t> col_offsets;
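Switching FBGEMM_API to CAFFE2_API matters once the code builds as Windows DLLs: these structs are compiled into the caffe2/ATen library rather than into fbgemm itself, so they must carry the owning library's import/export attribute. A sketch of the usual macro pattern behind such annotations (illustrative names, not the exact PyTorch definitions):

    // On Windows a symbol must be dllexport'ed by the DLL that defines it
    // and dllimport'ed by every consumer; elsewhere the macro only sets
    // ELF symbol visibility.
    #ifdef _WIN32
    #  ifdef BUILDING_MY_LIB   // defined while compiling the DLL itself
    #    define MY_LIB_API __declspec(dllexport)
    #  else
    #    define MY_LIB_API __declspec(dllimport)
    #  endif
    #else
    #  define MY_LIB_API __attribute__((visibility("default")))
    #endif

    struct MY_LIB_API PackedThing { /* ... */ };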
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
index 58c64c1..f153a8a 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -257,7 +257,8 @@
const unsigned short significand_bits = value & 0x3ff;
const float sign = sign_bits ? -1 : 1;
- const float significand = 1 + significand_bits * 0x1p-10;
+ const float significand = 1 +
+ significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10
const float exponent = exponent_bits - 0xf;
return sign * std::ldexp(significand, exponent);
diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp
index f6b90dc..732b835 100644
--- a/aten/src/ATen/quantized/Quantizer.cpp
+++ b/aten/src/ATen/quantized/Quantizer.cpp
@@ -133,10 +133,9 @@
template <typename T>
inline float dequantize_val(double scale, int64_t zero_point, T value) {
- fbgemm::TensorQuantizationParams qparams = {
- .scale = static_cast<float>(scale),
- .zero_point = static_cast<int32_t>(zero_point)
- };
+ fbgemm::TensorQuantizationParams qparams;
+ qparams.scale = static_cast<float>(scale);
+ qparams.zero_point = static_cast<int32_t>(zero_point);
return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
}
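The removed braces used designated initializers (.scale = ..., .zero_point = ...), a C99 construct that GCC and Clang accept in C++ as an extension but that MSVC rejects before C++20; plain member assignment is the portable spelling. A stand-alone before/after sketch with a stand-in struct:

    #include <cstdint>

    struct Params { float scale; std::int32_t zero_point; };

    Params make_params(double scale, std::int64_t zero_point) {
      // Rejected by MSVC pre-C++20 (a GNU extension in GCC/Clang):
      //   Params p = { .scale = static_cast<float>(scale),
      //                .zero_point = static_cast<std::int32_t>(zero_point) };
      Params p;
      p.scale = static_cast<float>(scale);
      p.zero_point = static_cast<std::int32_t>(zero_point);
      return p;
    }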
diff --git a/caffe2/quantization/server/CMakeLists.txt b/caffe2/quantization/server/CMakeLists.txt
index e48fab3..24d3ae2 100644
--- a/caffe2/quantization/server/CMakeLists.txt
+++ b/caffe2/quantization/server/CMakeLists.txt
@@ -60,13 +60,20 @@
#"${CMAKE_CURRENT_SOURCE_DIR}/sigmoid_test.cc")
#"${CMAKE_CURRENT_SOURCE_DIR}/tanh_test.cc")
-if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
+if (CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
add_library(caffe2_dnnlowp_avx2_ops OBJECT ${caffe2_dnnlowp_avx2_ops_SRCS})
add_dependencies(caffe2_dnnlowp_avx2_ops fbgemm Caffe2_PROTO c10)
target_include_directories(caffe2_dnnlowp_avx2_ops BEFORE
PRIVATE $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>)
- set_property(SOURCE ${caffe2_dnnlowp_avx2_ops_SRCS}
- APPEND_STRING PROPERTY COMPILE_FLAGS " -mavx2 -mfma -mf16c -mxsave ")
+
+ if (MSVC)
+ set_property(SOURCE ${caffe2_dnnlowp_avx2_ops_SRCS}
+ APPEND_STRING PROPERTY COMPILE_FLAGS " /arch:AVX2 ")
+ else()
+ set_property(SOURCE ${caffe2_dnnlowp_avx2_ops_SRCS}
+ APPEND_STRING PROPERTY COMPILE_FLAGS " -mavx2 -mfma -mf16c -mxsave ")
+ endif()
+
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
$<TARGET_OBJECTS:caffe2_dnnlowp_avx2_ops>)
endif()
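MSVC has no per-feature switches comparable to -mavx2/-mfma/-mf16c/-mxsave; /arch:AVX2 is the one flag covering those instruction families. Note the asymmetry: GCC and Clang refuse to compile an intrinsic unless the matching -m flag is on, while MSVC accepts intrinsics unconditionally and uses /arch:AVX2 mainly to let the optimizer emit VEX-encoded code throughout. A quick sanity check (hypothetical function, not from the tree):

    #include <immintrin.h>

    // Compiles under "-mavx2 -mfma" (GCC/Clang) or "/arch:AVX2" (MSVC).
    __m256 fused_multiply_add(__m256 a, __m256 b, __m256 c) {
      return _mm256_fmadd_ps(a, b, c); // a * b + c on 8 floats at once
    }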
diff --git a/caffe2/quantization/server/activation_distribution_observer.h b/caffe2/quantization/server/activation_distribution_observer.h
index 3d9249c..02a6de6 100644
--- a/caffe2/quantization/server/activation_distribution_observer.h
+++ b/caffe2/quantization/server/activation_distribution_observer.h
@@ -221,4 +221,14 @@
const std::string& qparams_output_file_name = "");
};
+#ifdef _MSC_VER
+inline struct tm* localtime_r(time_t* _clock, struct tm* _result) {
+ struct tm* candidate_result = localtime(_clock);
+ if (candidate_result) {
+ *(_result) = *candidate_result;
+ }
+ return candidate_result;
+}
+#endif
+
} // namespace caffe2
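localtime_r is POSIX-only, so the block above shims it on Windows on top of the CRT's localtime; MSVC's localtime already returns a pointer into thread-local storage, so copying it out this way remains thread-safe. The shim is marked inline because its definition lives in a header and would otherwise trigger duplicate-symbol link errors once the header is included from several translation units. A usage sketch (hypothetical helper, not from the tree):

    #include <cstdio>
    #include <ctime>

    void print_now() {
      time_t now = time(nullptr);
      struct tm parts;
      if (localtime_r(&now, &parts)) { // resolves to the shim on MSVC
        char text[64];
        strftime(text, sizeof(text), "%Y-%m-%d %H:%M:%S", &parts);
        puts(text);
      }
    }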
diff --git a/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc b/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc
index 2df5076..bfec3e1 100644
--- a/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc
+++ b/caffe2/quantization/server/batch_matmul_dnnlowp_op.cc
@@ -430,8 +430,12 @@
Y_int32_.resize(Y->numel());
#ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
#pragma omp parallel for collapse(2)
#endif
+#endif
for (int p = 0; p < num_outer_batches; ++p) {
for (int i = 0; i < num_sub_batches; ++i) {
int tid = dnnlowp_get_thread_num();
@@ -489,8 +493,12 @@
A_pack_buf_.resize(A_pack_len_per_thread * dnnlowp_get_max_threads());
#ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
#pragma omp parallel for collapse(2)
#endif
+#endif
for (int p = 0; p < num_outer_batches; ++p) {
for (int i = 0; i < num_sub_batches; ++i) {
int tid = dnnlowp_get_thread_num();
@@ -544,8 +552,12 @@
A_pack_buf_len_per_thread * dnnlowp_get_max_threads());
#ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
#pragma omp parallel for collapse(2)
#endif
+#endif
for (int p = 0; p < num_outer_batches; ++p) {
for (int i = 0; i < num_sub_batches; ++i) {
int tid = dnnlowp_get_thread_num();
@@ -610,8 +622,12 @@
T* Y_quantized = GetQuantizedOutputData_();
Y_int32_.resize(Y->numel());
#ifdef _OPENMP
+#ifdef _MSC_VER
+#pragma omp parallel for
+#else
#pragma omp parallel for collapse(2)
#endif
+#endif
for (int p = 0; p < num_outer_batches; ++p) {
for (int i = 0; i < num_sub_batches; ++i) {
// Y_q = (scale_A * scale_B) / scale_Y * Y_int32
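MSVC implements only OpenMP 2.0, and the collapse clause arrived in OpenMP 3.0, hence each of the four hunks above parallelizes just the outer loop on Windows. An alternative that keeps the whole p × i iteration space parallel on every compiler is to fuse the two loops by hand; a sketch reusing the hunk's loop bounds:

    #ifdef _OPENMP
    #pragma omp parallel for
    #endif
    for (int pi = 0; pi < num_outer_batches * num_sub_batches; ++pi) {
      const int p = pi / num_sub_batches;
      const int i = pi % num_sub_batches;
      // ... same loop body as above, indexed by p and i ...
    }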
diff --git a/caffe2/quantization/server/conv_dnnlowp_op.cc b/caffe2/quantization/server/conv_dnnlowp_op.cc
index 3ebb2df..2821073 100644
--- a/caffe2/quantization/server/conv_dnnlowp_op.cc
+++ b/caffe2/quantization/server/conv_dnnlowp_op.cc
@@ -860,7 +860,7 @@
int32_t Y_min = numeric_limits<int32_t>::max();
int32_t Y_max = numeric_limits<int32_t>::min();
-#ifdef _OPENMP
+#if defined(_OPENMP) && !defined(_MSC_VER)
#pragma omp parallel for reduction(min : Y_min), reduction(max : Y_max)
#endif
for (int i = 0; i < N * Y_HxW; ++i) {
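Likewise, reduction(min : ...) and reduction(max : ...) only appeared in OpenMP 3.1, so this hunk simply serializes the loop under MSVC. Should that ever matter for performance, a portable parallel version can merge per-thread extrema by hand; a sketch reusing the hunk's names and assuming the loop body reads Y_int32_[i]:

    #ifdef _OPENMP
    #pragma omp parallel
    #endif
      {
        int32_t local_min = numeric_limits<int32_t>::max();
        int32_t local_max = numeric_limits<int32_t>::min();
    #ifdef _OPENMP
    #pragma omp for nowait
    #endif
        for (int i = 0; i < N * Y_HxW; ++i) {
          local_min = std::min(local_min, Y_int32_[i]);
          local_max = std::max(local_max, Y_int32_[i]);
        }
    #ifdef _OPENMP
    #pragma omp critical
    #endif
        {
          Y_min = std::min(Y_min, local_min);
          Y_max = std::max(Y_max, local_max);
        }
      }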
diff --git a/caffe2/quantization/server/dnnlowp.h b/caffe2/quantization/server/dnnlowp.h
index 0d7414f..2f68d15 100644
--- a/caffe2/quantization/server/dnnlowp.h
+++ b/caffe2/quantization/server/dnnlowp.h
@@ -6,7 +6,7 @@
#include <cstdint>
#include <limits>
-#include <x86intrin.h>
+#include <immintrin.h>
#include <fbgemm/QuantUtils.h>
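<x86intrin.h> is a GCC/Clang umbrella header that MSVC does not ship; <immintrin.h> is the portable header for the SSE/AVX intrinsic families on all three compilers, which is why every include in this patch switches over. A minimal cross-compiler snippet (requires SSE3, e.g. -msse3 on GCC/Clang; illustrative, not from the tree):

    #include <immintrin.h>

    // Horizontal sum of 4 packed floats.
    float horizontal_add(__m128 v) {
      __m128 shuf = _mm_movehdup_ps(v);  // [v1, v1, v3, v3]
      __m128 sums = _mm_add_ps(v, shuf); // [v0+v1, *, v2+v3, *]
      shuf = _mm_movehl_ps(shuf, sums);  // lane 0 becomes v2+v3
      sums = _mm_add_ss(sums, shuf);     // v0+v1+v2+v3 in lane 0
      return _mm_cvtss_f32(sums);
    }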
diff --git a/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc b/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc
index 450e641..457d702 100644
--- a/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc
+++ b/caffe2/quantization/server/elementwise_sum_dnnlowp_op.cc
@@ -3,6 +3,7 @@
#include <array>
#include <tuple>
#include <type_traits>
+#include <vector>
// #define DNNLOWP_MEASURE_TIME_BREAKDOWN
#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -102,8 +103,8 @@
out_qparams_.zero_point);
} // omp parallel
} else {
- RequantizationParams in_requantization_params[InputSize()];
- const T* input_data[InputSize()];
+ vector<RequantizationParams> in_requantization_params(InputSize());
+ vector<const T*> input_data(InputSize());
for (int i = 0; i < InputSize(); ++i) {
float real_multiplier =
in_qparams_[i].scale / intermediate_qparams_.scale;
@@ -136,9 +137,8 @@
}
}
}
- } // InputTensorCPU_(0).template IsType<T>()
- else {
- const float* input_data[InputSize()];
+ } else { // InputTensorCPU_(0).template IsType<T>()
+ vector<const float*> input_data(InputSize());
for (int i = 0; i < InputSize(); ++i) {
input_data[i] = InputTensorCPU_(i).template data<float>();
}
@@ -155,7 +155,7 @@
int32_t acc = 0;
for (int i = 0; i < InputSize(); ++i) {
acc += fbgemm::Quantize<int32_t>(
- ((const float*)input_data[i])[j],
+ input_data[i][j],
intermediate_qparams_.zero_point,
intermediate_qparams_.scale,
qfactory_->GetEltwiseQuantizePrecision());
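The replaced arrays sized by InputSize() were variable-length arrays, a C99 feature that GCC and Clang tolerate in C++ as an extension but MSVC rejects outright; std::vector supplies the same runtime-sized storage portably, which is also why the patch adds the <vector> include. A stand-alone sketch of the pattern:

    #include <vector>

    void demo(int n) {
      // VLA, non-portable in C++: const float* input_data[n];
      std::vector<const float*> input_data(n); // portable equivalent
      // ... fill input_data[0..n-1] and index it exactly as before ...
    }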
diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc
index 085a41c..f1dc732 100644
--- a/caffe2/quantization/server/fbgemm_pack_op.cc
+++ b/caffe2/quantization/server/fbgemm_pack_op.cc
@@ -736,7 +736,8 @@
std::vector<int32_t> offsets;
for (const auto v : dnntensor.qparams) {
scales.push_back(v.scale);
- offsets.push_back(reinterpret_cast<int32_t>(v.zero_point));
+ int32_t cur_offset = v.zero_point;
+ offsets.push_back(cur_offset);
}
all_scales->push_back(scales);
all_offsets->push_back(offsets);
@@ -784,7 +785,8 @@
std::vector<int32_t> offsets;
for (const auto v : dnntensor.qparams) {
scales.push_back(v.scale);
- offsets.push_back(reinterpret_cast<int32_t>(v.zero_point));
+ int32_t cur_offset = v.zero_point;
+ offsets.push_back(cur_offset);
}
all_scales->push_back(scales);
all_offsets->push_back(offsets);
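reinterpret_cast reinterprets pointers, references, and pointer-integer pairs; it is not a value conversion between integer types, so MSVC rejects the old spelling here. An ordinary integral conversion (implicit, as in the hunk, or an explicit static_cast) expresses the intent; a minimal sketch with a hypothetical function name:

    #include <cstdint>

    std::int32_t to_offset(std::int64_t zero_point) {
      // reinterpret_cast<std::int32_t>(zero_point) is ill-formed;
      // a plain value conversion is what was meant:
      return static_cast<std::int32_t>(zero_point);
    }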
diff --git a/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc b/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc
index c66bb97..7965e57 100644
--- a/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc
+++ b/caffe2/quantization/server/fully_connected_fake_lowp_op_avx2.cc
@@ -6,6 +6,7 @@
// NOTE: clang-format wants to use a different formatting but the
// current formatting should be easier to read.
+// clang-format off
alignas(64) const int ld_st_masks[8][8] = {
{ 0, 0, 0, 0, 0, 0, 0, 0, },
{ -1, 0, 0, 0, 0, 0, 0, 0, },
@@ -16,6 +17,7 @@
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
+// clang-format on
} // anonymous namespace
diff --git a/caffe2/quantization/server/norm_minimization.cc b/caffe2/quantization/server/norm_minimization.cc
index ac53daa..94e655e 100644
--- a/caffe2/quantization/server/norm_minimization.cc
+++ b/caffe2/quantization/server/norm_minimization.cc
@@ -6,7 +6,7 @@
#include <cmath>
#include <limits>
-#include <x86intrin.h>
+#include <immintrin.h>
using namespace std;
@@ -160,8 +160,8 @@
start_bin = next_start_bin;
end_bin = next_end_bin;
}
- VLOG(2) << "best quantization range " << start_bin << "," << end_bin + 1 << ","
- << norm_min;
+ VLOG(2) << "best quantization range " << start_bin << "," << end_bin + 1
+ << "," << norm_min;
double selected_sum = 0;
for (int i = start_bin; i < end_bin + 1; ++i) {
diff --git a/caffe2/quantization/server/transpose.cc b/caffe2/quantization/server/transpose.cc
index ac53707..4cf3e36 100644
--- a/caffe2/quantization/server/transpose.cc
+++ b/caffe2/quantization/server/transpose.cc
@@ -1,6 +1,6 @@
#include "transpose.h"
-#include <x86intrin.h>
+#include <immintrin.h>
namespace fbgemm {
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 03f7b0e..7618a3c 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -523,17 +523,14 @@
"Turn this warning off by USE_FBGEMM=OFF.")
set(USE_FBGEMM OFF)
endif()
- if(MSVC)
- message(WARNING
- "FBGEMM is currently not supported on windows with MSVC. "
- "Not compiling with FBGEMM. "
- "Turn this warning off by USE_FBGEMM=OFF.")
- set(USE_FBGEMM OFF)
- endif()
if(USE_FBGEMM AND NOT TARGET fbgemm)
set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "")
set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "")
- set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "")
+ if(MSVC AND BUILD_SHARED_LIBS)
+ set(FBGEMM_LIBRARY_TYPE "shared" CACHE STRING "")
+ else()
+ set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "")
+ endif()
add_subdirectory("${FBGEMM_SOURCE_DIR}")
set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON)