Fixes a bug in serializing min/max clamping values, plus one related type fix. (#35850)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35850
1. Clamping values (output_min/output_max) were not being propagated into the op-context structures and hence were never serialized. (A condensed sketch of the propagation fix appears after the diff below.)
2. Moved to using Scalar for min/max instead of float. The reason is that the fusion for hardtanh_ did not work otherwise: during subgraph rewrite we route values from hardtanh_ directly into the prepacking ops, but since those ops expected float values, the types conflicted and the model could not be serialized. (A minimal sketch of the resulting Scalar-to-float lowering follows the commit metadata below.)
Test Plan: Imported from OSS
Differential Revision: D20807523
fbshipit-source-id: 57d6b2e4b65afd9510a0f3ba9365333b768977f5
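To illustrate point (2), the sketch below shows the kind of conversion now done inside create_context(). This is a minimal, hypothetical illustration, not the actual ATen sources: lower_bound_to_float and the local kMin/kMax are invented names for this sketch (the diff uses xnnpack::ContextLinear::kMin and friends). The point is that a c10::optional<Scalar> bound is lowered to the float XNNPACK expects, and Scalar::to<float>() accepts bounds that arrive as int or double, which is what lets the hardtanh_ -> prepack rewrite type-check.

    // Minimal sketch, not the actual ATen sources: how an optional
    // Scalar clamping bound is lowered to the float XNNPACK expects.
    // lower_bound_to_float and the local kMin/kMax are hypothetical
    // names; the real code uses xnnpack::ContextLinear::kMin etc.
    #include <limits>
    #include <c10/core/Scalar.h>
    #include <c10/util/Optional.h>

    constexpr float kMin = -std::numeric_limits<float>::infinity();
    constexpr float kMax = +std::numeric_limits<float>::infinity();

    float lower_bound_to_float(
        const c10::optional<c10::Scalar>& bound, float fallback) {
      // Scalar::to<float>() converts whatever numeric type the bound
      // arrived as (int, double, ...), so a Scalar coming straight out
      // of hardtanh_ no longer conflicts with the prepack op's schema.
      return bound ? bound->to<float>() : fallback;
    }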
diff --git a/aten/src/ATen/native/xnnpack/Convolution.cpp b/aten/src/ATen/native/xnnpack/Convolution.cpp
index e5eb9ad..1548ab2 100644
--- a/aten/src/ATen/native/xnnpack/Convolution.cpp
+++ b/aten/src/ATen/native/xnnpack/Convolution.cpp
@@ -229,8 +229,8 @@
std::vector<int64_t> padding,
std::vector<int64_t> dilation,
int64_t groups,
- c10::optional<double> output_min,
- c10::optional<double> output_max) {
+ c10::optional<Scalar> output_min,
+ c10::optional<Scalar> output_max) {
return xnnpack::XNNPackConv2dOpContext::create_context(
std::move(weight),
std::move(bias),
diff --git a/aten/src/ATen/native/xnnpack/Convolution.h b/aten/src/ATen/native/xnnpack/Convolution.h
index 9cf7c6b..6313856 100644
--- a/aten/src/ATen/native/xnnpack/Convolution.h
+++ b/aten/src/ATen/native/xnnpack/Convolution.h
@@ -20,8 +20,8 @@
std::vector<int64_t> padding,
std::vector<int64_t> dilation,
int64_t groups,
- c10::optional<double> output_min,
- c10::optional<double> output_max);
+ c10::optional<Scalar> output_min,
+ c10::optional<Scalar> output_max);
class Conv2dClampRun final : public torch::OperatorKernel {
public:
diff --git a/aten/src/ATen/native/xnnpack/Linear.cpp b/aten/src/ATen/native/xnnpack/Linear.cpp
index c4a0aa7..0b4c80a8 100644
--- a/aten/src/ATen/native/xnnpack/Linear.cpp
+++ b/aten/src/ATen/native/xnnpack/Linear.cpp
@@ -150,8 +150,8 @@
c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
Tensor weight,
c10::optional<Tensor> bias,
- c10::optional<double> output_min,
- c10::optional<double> output_max) {
+ c10::optional<Scalar> output_min,
+ c10::optional<Scalar> output_max) {
return xnnpack::XNNPackLinearOpContext::create_context(
std::move(weight), std::move(bias), output_min, output_max);
}
diff --git a/aten/src/ATen/native/xnnpack/Linear.h b/aten/src/ATen/native/xnnpack/Linear.h
index 38fe3a2..e0137e6 100644
--- a/aten/src/ATen/native/xnnpack/Linear.h
+++ b/aten/src/ATen/native/xnnpack/Linear.h
@@ -15,8 +15,8 @@
c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
Tensor weight,
c10::optional<Tensor> bias,
- c10::optional<double> output_min,
- c10::optional<double> output_max);
+ c10::optional<Scalar> output_min,
+ c10::optional<Scalar> output_max);
class LinearClampRun final : public torch::OperatorKernel {
public:
diff --git a/aten/src/ATen/native/xnnpack/OpContext.cpp b/aten/src/ATen/native/xnnpack/OpContext.cpp
index 02cd767..3ba8c31 100644
--- a/aten/src/ATen/native/xnnpack/OpContext.cpp
+++ b/aten/src/ATen/native/xnnpack/OpContext.cpp
@@ -11,18 +11,20 @@
XNNPackLinearOpContext::create_context(
at::Tensor&& weight,
c10::optional<at::Tensor>&& bias,
- const c10::optional<double> output_min,
- const c10::optional<double> output_max) {
+ const c10::optional<Scalar> output_min,
+ const c10::optional<Scalar> output_max) {
auto linear_op_context =
c10::make_intrusive<XNNPackLinearOpContext>(
std::move(weight),
std::move(bias),
+ output_min,
+ output_max,
xnnpack::internal::linear::create(
weight,
bias,
- output_min ? static_cast<float>(*output_min)
+ output_min ? output_min->to<float>()
: xnnpack::ContextLinear::kMin,
- output_max ? static_cast<float>(*output_max)
+ output_max ? output_max->to<float>()
: xnnpack::ContextLinear::kMax)
);
return linear_op_context;
@@ -39,8 +41,8 @@
std::vector<int64_t>&& stride,
std::vector<int64_t>&& dilation,
int64_t groups,
- const c10::optional<double> output_min,
- const c10::optional<double> output_max) {
+ const c10::optional<Scalar> output_min,
+ const c10::optional<Scalar> output_max) {
auto op_context =
xnnpack::internal::convolution2d::create(
weight,
@@ -49,9 +51,9 @@
stride,
dilation,
groups,
- output_min ? static_cast<float>(*output_min)
+ output_min ? output_min->to<float>()
: xnnpack::ContextConv2D::kMin,
- output_max ? static_cast<float>(*output_max)
+ output_max ? output_max->to<float>()
: xnnpack::ContextConv2D::kMax);
auto conv2d_op_context =
c10::make_intrusive<XNNPackConv2dOpContext>(
@@ -61,6 +63,8 @@
std::move(stride),
std::move(dilation),
groups,
+ output_min,
+ output_max,
std::move(op_context));
return conv2d_op_context;
}
diff --git a/aten/src/ATen/native/xnnpack/OpContext.h b/aten/src/ATen/native/xnnpack/OpContext.h
index d151a77..d46d5ee 100644
--- a/aten/src/ATen/native/xnnpack/OpContext.h
+++ b/aten/src/ATen/native/xnnpack/OpContext.h
@@ -13,8 +13,8 @@
using SerializationTypeLinearPrePack = std::tuple<
Tensor,
c10::optional<Tensor>,
- c10::optional<double>,
- c10::optional<double>>;
+ c10::optional<Scalar>,
+ c10::optional<Scalar>>;
using SerializationTypeConv2dPrePack = std::tuple<
Tensor,
c10::optional<Tensor>,
@@ -22,15 +22,15 @@
std::vector<int64_t>,
std::vector<int64_t>,
int64_t,
- c10::optional<double>,
- c10::optional<double>>;
+ c10::optional<Scalar>,
+ c10::optional<Scalar>>;
class LinearOpContext : public torch::jit::CustomClassHolder {
protected:
Tensor orig_weight_;
c10::optional<Tensor> orig_bias_;
- c10::optional<double> output_min_;
- c10::optional<double> output_max_;
+ c10::optional<Scalar> output_min_;
+ c10::optional<Scalar> output_max_;
public:
SerializationTypeLinearPrePack unpack() {
@@ -48,10 +48,14 @@
XNNPackLinearOpContext(
Tensor&& weight,
c10::optional<Tensor>&& bias,
+ c10::optional<Scalar> min,
+ c10::optional<Scalar> max,
ContextLinear&& op_context)
: op_context_(std::move(op_context)) {
orig_weight_ = std::move(weight);
orig_bias_ = std::move(bias);
+ output_min_ = min;
+ output_max_ = max;
}
Tensor run(const Tensor& input);
@@ -59,8 +63,8 @@
static c10::intrusive_ptr<LinearOpContext> create_context(
Tensor&& weight,
c10::optional<Tensor>&& bias,
- const c10::optional<double> output_min,
- const c10::optional<double> output_max);
+ const c10::optional<Scalar> output_min,
+ const c10::optional<Scalar> output_max);
};
class Conv2dOpContext : public torch::jit::CustomClassHolder {
@@ -71,8 +75,8 @@
std::vector<int64_t> padding_;
std::vector<int64_t> dilation_;
int64_t groups_;
- c10::optional<double> output_min_;
- c10::optional<double> output_max_;
+ c10::optional<Scalar> output_min_;
+ c10::optional<Scalar> output_max_;
public:
SerializationTypeConv2dPrePack unpack() {
@@ -102,6 +106,8 @@
std::vector<int64_t>&& stride,
std::vector<int64_t>&& dilation,
uint64_t groups,
+ c10::optional<Scalar> min,
+ c10::optional<Scalar> max,
ContextConv2D&& op_context)
: op_context_(std::move(op_context)) {
orig_weight_ = std::move(weight);
@@ -110,6 +116,8 @@
stride_ = std::move(stride);
dilation_ = std::move(dilation);
groups_ = groups;
+ output_min_ = min;
+ output_max_ = max;
}
Tensor run(const Tensor& input);
@@ -121,8 +129,8 @@
std::vector<int64_t>&& stride,
std::vector<int64_t>&& dilation,
int64_t groups,
- const c10::optional<double> output_min,
- const c10::optional<double> output_max);
+ const c10::optional<Scalar> output_min,
+ const c10::optional<Scalar> output_max);
};
} // namespace xnnpack
diff --git a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp
index edeea77..cd9beac 100644
--- a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp
+++ b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp
@@ -66,7 +66,7 @@
// We can refactor the code and use a better namespace.
torch::RegisterOperators()
.op("prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, "
- "float? output_min=None, float? output_max=None) "
+ "Scalar? output_min=None, Scalar? output_max=None) "
"-> __torch__.torch.classes.xnnpack.LinearOpContext",
torch::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
@@ -81,7 +81,7 @@
DispatchKey::CPUTensorId))
.op("prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, "
"int[2] padding, int[2] dilation, int groups, "
- "float? output_min=None, float? output_max=None) "
+ "Scalar? output_min=None, Scalar? output_max=None) "
"-> __torch__.torch.classes.xnnpack.Conv2dOpContext",
torch::RegisterOperators::options()
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
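For point (1), the key observation is that serialization goes through unpack(), so any value not copied onto the op context is silently dropped. Below is a condensed, hypothetical sketch of the propagation fix; ToyLinearOpContext is a made-up stand-in for the XNNPackLinearOpContext / XNNPackConv2dOpContext classes patched above.

    // Condensed sketch of the propagation fix; ToyLinearOpContext is a
    // hypothetical stand-in for the op contexts patched in this diff.
    #include <tuple>
    #include <utility>
    #include <c10/core/Scalar.h>
    #include <c10/util/Optional.h>

    struct ToyLinearOpContext {
      c10::optional<c10::Scalar> output_min_;
      c10::optional<c10::Scalar> output_max_;

      // Before the fix the constructor never received min/max, so these
      // members stayed unset; the fix threads the bounds through here.
      ToyLinearOpContext(
          c10::optional<c10::Scalar> min, c10::optional<c10::Scalar> max)
          : output_min_(std::move(min)), output_max_(std::move(max)) {}

      // The serializer calls unpack(); fields never stored on the
      // context never reach the serialized model, which was bug (1).
      std::tuple<c10::optional<c10::Scalar>, c10::optional<c10::Scalar>>
      unpack() const {
        return std::make_tuple(output_min_, output_max_);
      }
    };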