[NOOP][clangformat][codemod] Enable CLANGFORMAT for some folders in caffe2/* (#67746)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67746
Test Plan: Visual inspection. Sandcastle.
Reviewed By: zertosh
Differential Revision: D31986646
fbshipit-source-id: 91885c20c3cead3853c49abb9fe0a94a67f33cc8
diff --git a/android/pytorch_android/generate_test_asset.cpp b/android/pytorch_android/generate_test_asset.cpp
index de07464..6105a09 100644
--- a/android/pytorch_android/generate_test_asset.cpp
+++ b/android/pytorch_android/generate_test_asset.cpp
@@ -1,9 +1,9 @@
+#include <torch/csrc/jit/api/module.h>
#include <torch/jit.h>
#include <torch/script.h>
-#include <torch/csrc/jit/api/module.h>
-#include <iostream>
#include <fstream>
+#include <iostream>
#include <string>
int main(int argc, char* argv[]) {
diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp
index a0a8ec3..4592537 100644
--- a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp
+++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp
@@ -614,8 +614,8 @@
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
- c10::impl::GenericDict dict{c10::StringType::get(),
- c10::unshapedType(firstEntryValue.type())};
+ c10::impl::GenericDict dict{
+ c10::StringType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert(it->first->toStdString(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
@@ -637,8 +637,8 @@
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
- c10::impl::GenericDict dict{c10::IntType::get(),
- c10::unshapedType(firstEntryValue.type())};
+ c10::impl::GenericDict dict{
+ c10::IntType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert((int64_t)it->first->longValue(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
diff --git a/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp b/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp
index 0111184..41fc7d5 100644
--- a/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp
+++ b/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp
@@ -105,7 +105,8 @@
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
- uvIdx = (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
+ uvIdx =
+ (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
yi = yData[yIdx];
@@ -131,7 +132,8 @@
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
- uvIdx = (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
+ uvIdx =
+ (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
yi = yData[yIdx];
@@ -152,7 +154,7 @@
}
} else {
jclass Exception = jniEnv->FindClass("java/lang/IllegalArgumentException");
- jniEnv->ThrowNew(Exception,"Illegal memory format code");
+ jniEnv->ThrowNew(Exception, "Illegal memory format code");
}
}
} // namespace pytorch_vision_jni
diff --git a/benchmarks/cpp/tensorexpr/bench_approx.cpp b/benchmarks/cpp/tensorexpr/bench_approx.cpp
index f2b3975..208b333 100644
--- a/benchmarks/cpp/tensorexpr/bench_approx.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_approx.cpp
@@ -1,11 +1,11 @@
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
+#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
-#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/torch.h>
-#include "caffe2/operators/tanh_op.h"
#include "caffe2/operators/logit_op.h"
+#include "caffe2/operators/tanh_op.h"
using namespace torch::jit;
using namespace torch::jit::tensorexpr;
@@ -32,7 +32,7 @@
auto N = VarHandle("N", kInt);
BufHandle A("A", {N}, kFloat);
auto clamp = 0;
- torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i){
+ torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
auto A_elem = [&]() {
auto elem = A.load(i);
auto min = FloatImm::make(clamp);
@@ -55,20 +55,19 @@
auto B_ref = at::relu(A_t);
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
TORCH_CHECK(at::allclose(B_t, B_ref));
- for (auto _ : state){
+ for (auto _ : state) {
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["log/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void log_nnc_sleef(benchmark::State& state) {
auto N = VarHandle("N", kInt);
BufHandle A("A", {N}, kFloat);
torch::jit::tensorexpr::Tensor B =
- Compute("B", {N}, [&](const VarHandle& i) {
- return log(A.load(i));
- });
+ Compute("B", {N}, [&](const VarHandle& i) { return log(A.load(i)); });
LoopNest ln({B});
ln.prepareForCodegen();
vectorize(&ln, B, 8);
@@ -88,16 +87,15 @@
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["log/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void log_nnc_fast(benchmark::State& state) {
auto N = VarHandle("N", kInt);
BufHandle A("A", {N}, kFloat);
- torch::jit::tensorexpr::Tensor B =
- Compute("B", {N}, [&](const VarHandle& i) {
- return fast_log(A.load(i));
- });
+ torch::jit::tensorexpr::Tensor B = Compute(
+ "B", {N}, [&](const VarHandle& i) { return fast_log(A.load(i)); });
LoopNest ln({B});
optimizePointwise(&ln, B);
ln.prepareForCodegen();
@@ -117,16 +115,15 @@
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["log/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void log_nnc_vml(benchmark::State& state) {
auto N = VarHandle("N", kInt);
BufHandle A("A", {N}, kFloat);
torch::jit::tensorexpr::Tensor B =
- Compute("B", {N}, [&](const VarHandle& i) {
- return log_vml(A.load(i));
- });
+ Compute("B", {N}, [&](const VarHandle& i) { return log_vml(A.load(i)); });
LoopNest ln({B});
vectorize(&ln, B, 8);
ln.prepareForCodegen();
@@ -146,7 +143,8 @@
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["log/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void log_aten(benchmark::State& state) {
@@ -156,7 +154,8 @@
at::log_out(B_t, A_t);
}
state.counters["log/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void logit_nnc_sleef(benchmark::State& state) {
@@ -192,7 +191,8 @@
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["logit/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void logit_nnc_fast(benchmark::State& state) {
@@ -228,7 +228,8 @@
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["logit/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void logit_nnc_vml(benchmark::State& state) {
@@ -264,7 +265,8 @@
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["logit/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void logit_aten(benchmark::State& state) {
@@ -275,7 +277,8 @@
at::native::logit_out(A_t, clamp, B_t);
}
state.counters["logit/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
template <typename T>
@@ -305,16 +308,15 @@
}
state.counters["logit/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void tanh_nnc_fast(benchmark::State& state) {
auto N = VarHandle("N", kInt);
BufHandle A("A", {N}, kFloat);
- torch::jit::tensorexpr::Tensor B =
- Compute("B", {N}, [&](const VarHandle& i) {
- return fast_tanh(A.load(i));
- });
+ torch::jit::tensorexpr::Tensor B = Compute(
+ "B", {N}, [&](const VarHandle& i) { return fast_tanh(A.load(i)); });
LoopNest ln({B});
optimizePointwise(&ln, B);
ln.prepareForCodegen();
@@ -334,7 +336,8 @@
cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
}
state.counters["tanh/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void tanh_aten(benchmark::State& state) {
@@ -344,7 +347,8 @@
at::tanh_out(A_t, B_t);
}
state.counters["tanh/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
static void tanh_caffe2(benchmark::State& state) {
@@ -365,71 +369,63 @@
tanh(N, X, Y, &c);
}
state.counters["tanh/s"] = benchmark::Counter(
- uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate);
+ uint64_t(state.range(0) * state.iterations()),
+ benchmark::Counter::kIsRate);
}
-BENCHMARK(relu_nnc)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+BENCHMARK(relu_nnc)->Args({2 << 5})->Args({2 << 8})->Args({2 << 12})->Args(
+ {2 << 14});
BENCHMARK(log_nnc_sleef)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
BENCHMARK(log_nnc_fast)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
BENCHMARK(log_nnc_vml)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
-BENCHMARK(log_aten)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
+BENCHMARK(log_aten)->Args({2 << 5})->Args({2 << 8})->Args({2 << 12})->Args(
+ {2 << 14});
BENCHMARK(logit_nnc_sleef)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
BENCHMARK(logit_nnc_fast)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
BENCHMARK(logit_nnc_vml)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
BENCHMARK(logit_aten)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
BENCHMARK(logit_caffe2)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
BENCHMARK(tanh_nnc_fast)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
-BENCHMARK(tanh_aten)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
+BENCHMARK(tanh_aten)->Args({2 << 5})->Args({2 << 8})->Args({2 << 12})->Args(
+ {2 << 14});
BENCHMARK(tanh_caffe2)
- ->Args({2<<5})
- ->Args({2<<8})
- ->Args({2<<12})
- ->Args({2<<14});
+ ->Args({2 << 5})
+ ->Args({2 << 8})
+ ->Args({2 << 12})
+ ->Args({2 << 14});
diff --git a/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp b/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp
index 4753ca9..eddac0a 100644
--- a/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp
@@ -74,7 +74,6 @@
}
BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) {
-
BufHandle input("input", {N_, C_, H_, W_}, kFloat);
BufHandle weight("weight", {C_}, kFloat);
BufHandle bias("bias", {C_}, kFloat);
@@ -136,7 +135,6 @@
}
BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) {
-
BufHandle input("input", {N_, C_, H_, W_}, kFloat);
BufHandle weight("weight", {C_}, kFloat);
BufHandle bias("bias", {C_}, kFloat);
diff --git a/benchmarks/cpp/tensorexpr/bench_compile.cpp b/benchmarks/cpp/tensorexpr/bench_compile.cpp
index 7856c1d..13a02ee 100644
--- a/benchmarks/cpp/tensorexpr/bench_compile.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_compile.cpp
@@ -1,8 +1,8 @@
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
+#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
-#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#ifdef TORCH_ENABLE_LLVM
namespace te = torch::jit::tensorexpr;
@@ -12,21 +12,26 @@
constexpr int N = 512;
te::VarHandle n("n", te::kInt);
te::BufHandle A("A", {N}, te::kFloat);
- te::Tensor relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
- return te::Max::make(A.load(i), 0.f, false);
- });
- te::Tensor min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
- return te::Min::make(relu.load(i), 6.f, false);
- });
- te::Tensor plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
- return min6.load(i) + 3.f;
- });
- te::Tensor times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
- return A.load(i) * plus3.load(i);
- });
- te::Tensor sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
- return times.load(i) * 1.f / 6.f;
- });
+ te::Tensor relu =
+ te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return te::Max::make(A.load(i), 0.f, false);
+ });
+ te::Tensor min6 =
+ te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return te::Min::make(relu.load(i), 6.f, false);
+ });
+ te::Tensor plus3 =
+ te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return min6.load(i) + 3.f;
+ });
+ te::Tensor times =
+ te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return A.load(i) * plus3.load(i);
+ });
+ te::Tensor sixth =
+ te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return times.load(i) * 1.f / 6.f;
+ });
te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
for (auto tensor : {relu, min6, plus3, times}) {
nest.computeInline(tensor.buf());
@@ -41,21 +46,26 @@
constexpr int N = 512;
te::VarHandle n("n", te::kInt);
te::BufHandle A("A", {N}, te::kFloat);
- te::Tensor relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
- return te::Max::make(A.load(i), 0.f, false);
- });
- te::Tensor min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
- return te::Min::make(relu.load(i), 6.f, false);
- });
- te::Tensor plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
- return min6.load(i) + 3.f;
- });
- te::Tensor times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
- return A.load(i) * plus3.load(i);
- });
- te::Tensor sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
- return times.load(i) * 1.f / 6.f;
- });
+ te::Tensor relu =
+ te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return te::Max::make(A.load(i), 0.f, false);
+ });
+ te::Tensor min6 =
+ te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return te::Min::make(relu.load(i), 6.f, false);
+ });
+ te::Tensor plus3 =
+ te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return min6.load(i) + 3.f;
+ });
+ te::Tensor times =
+ te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return A.load(i) * plus3.load(i);
+ });
+ te::Tensor sixth =
+ te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
+ return times.load(i) * 1.f / 6.f;
+ });
te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
for (auto tensor : {relu, min6, plus3, times}) {
nest.computeInline(tensor.buf());
diff --git a/benchmarks/cpp/tensorexpr/bench_concat.cpp b/benchmarks/cpp/tensorexpr/bench_concat.cpp
index 70bfb42..d89856d 100644
--- a/benchmarks/cpp/tensorexpr/bench_concat.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_concat.cpp
@@ -47,7 +47,6 @@
}
void runNNC(benchmark::State& state) {
-
size_t num_inputs = inputs_.size();
size_t num_dims = 2;
@@ -100,7 +99,6 @@
}
void runNNCLoop(benchmark::State& state) {
-
size_t num_inputs = inputs_.size();
size_t num_dims = 2;
diff --git a/benchmarks/cpp/tensorexpr/bench_fuser_overhead.cpp b/benchmarks/cpp/tensorexpr/bench_fuser_overhead.cpp
index 5a31312..8ef530b 100644
--- a/benchmarks/cpp/tensorexpr/bench_fuser_overhead.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_fuser_overhead.cpp
@@ -1,7 +1,7 @@
#include <benchmark/benchmark.h>
+#include <c10/core/InferenceMode.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/torch.h>
-#include <c10/core/InferenceMode.h>
using namespace torch::jit;
diff --git a/benchmarks/cpp/tensorexpr/bench_gemm.cpp b/benchmarks/cpp/tensorexpr/bench_gemm.cpp
index 568d40d..a860c10 100644
--- a/benchmarks/cpp/tensorexpr/bench_gemm.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_gemm.cpp
@@ -31,7 +31,7 @@
at::Tensor B;
at::Tensor C;
};
-}
+} // namespace
BENCHMARK_DEFINE_F(Gemm, Torch)(benchmark::State& state) {
for (auto _ : state) {
@@ -40,7 +40,6 @@
}
BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) {
-
te::BufHandle AP("A", {M, K}, te::kFloat);
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
@@ -63,7 +62,6 @@
}
BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) {
-
te::BufHandle AP("A", {M, K}, te::kFloat);
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
@@ -122,7 +120,6 @@
}
BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) {
-
te::BufHandle AP("A", {M, K}, te::kFloat);
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
@@ -181,7 +178,6 @@
}
BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
-
te::BufHandle AP("A", {M, K}, te::kFloat);
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
@@ -248,7 +244,6 @@
}
BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) {
-
te::BufHandle AP("A", {M, K}, te::kFloat);
te::BufHandle BP("B", {K, N}, te::kFloat);
te::Tensor CT = te::Reduce(
diff --git a/benchmarks/cpp/tensorexpr/bench_parallel.cpp b/benchmarks/cpp/tensorexpr/bench_parallel.cpp
index 8f98c98..23bc869 100644
--- a/benchmarks/cpp/tensorexpr/bench_parallel.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_parallel.cpp
@@ -24,8 +24,8 @@
}
void TearDown(benchmark::State& state) override {
- state.counters["tasks"] = benchmark::Counter(uint64_t(state.iterations()) * M,
- benchmark::Counter::kIsRate);
+ state.counters["tasks"] = benchmark::Counter(
+ uint64_t(state.iterations()) * M, benchmark::Counter::kIsRate);
}
int M;
@@ -37,10 +37,9 @@
BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
BufHandle a_buf("a", {M}, kFloat);
BufHandle b_buf("b", {M}, kFloat);
- Tensor c_tensor = Compute(
- "c", {{M, "m"}}, [&](const VarHandle& m) {
- return a_buf.load(m) + b_buf.load(m);
- });
+ Tensor c_tensor = Compute("c", {{M, "m"}}, [&](const VarHandle& m) {
+ return a_buf.load(m) + b_buf.load(m);
+ });
LoopNest loop_nest({c_tensor});
auto const& loops = loop_nest.getLoopStmtsFor(c_tensor);
ForPtr m = loops[0];
diff --git a/benchmarks/cpp/tensorexpr/bench_reduce.cpp b/benchmarks/cpp/tensorexpr/bench_reduce.cpp
index 0db6753..fec8c89 100644
--- a/benchmarks/cpp/tensorexpr/bench_reduce.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_reduce.cpp
@@ -1,10 +1,10 @@
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
-#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
-#include <torch/csrc/jit/tensorexpr/tensor.h>
+#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
+#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
#include <immintrin.h>
@@ -25,8 +25,9 @@
void TearDown(benchmark::State& state) override {
TORCH_CHECK(at::allclose(B, ref, std::sqrt(A.numel()) * 1e-7));
- state.counters["BYTES"] = benchmark::Counter(uint64_t(state.iterations()) * M * sizeof(float),
- benchmark::Counter::kIsRate);
+ state.counters["BYTES"] = benchmark::Counter(
+ uint64_t(state.iterations()) * M * sizeof(float),
+ benchmark::Counter::kIsRate);
}
int M;
@@ -35,7 +36,7 @@
at::Tensor ref;
};
-} // namespace
+} // namespace
BENCHMARK_DEFINE_F(Reduce1D, Torch)(benchmark::State& state) {
for (auto _ : state) {
@@ -48,18 +49,23 @@
#define VALIDATE(F, A, B) ValidateFunc((F), #F, (A), (B))
template <typename Func>
-void ValidateFunc(Func func, const std::string& func_name, at::Tensor& A, at::Tensor& B) {
+void ValidateFunc(
+ Func func,
+ const std::string& func_name,
+ at::Tensor& A,
+ at::Tensor& B) {
func(A, B);
- float *pB = B.data_ptr<float>();
+ float* pB = B.data_ptr<float>();
at::Tensor B2 = torch::sum(A, {0});
- float *pB2 = B2.data_ptr<float>();
+ float* pB2 = B2.data_ptr<float>();
int size = A.numel();
float size_sqrt = std::sqrt(size);
float natural_noise = size_sqrt * 1e-7;
if (!torch::allclose(B, B2, natural_noise)) {
std::ostringstream oss;
oss << func_name << " failed check: " << std::endl;
- oss << "value: " << B << std::endl;;
+ oss << "value: " << B << std::endl;
+ ;
oss << "reference: " << B2 << std::endl;
oss << "threshold: " << natural_noise << std::endl;
throw std::runtime_error(oss.str());
@@ -67,8 +73,8 @@
}
static void reduce1d_naive(at::Tensor& A, at::Tensor& B) {
- float *pA = A.data_ptr<float>();
- float *pB = B.data_ptr<float>();
+ float* pA = A.data_ptr<float>();
+ float* pB = B.data_ptr<float>();
int size = A.numel();
TORCH_CHECK(B.numel() == 1);
*pB = 0.;
@@ -87,8 +93,8 @@
BENCHMARK_REGISTER_F(Reduce1D, Naive)->Args({1 << 24});
static void reduce1d_native_rfactor(at::Tensor& A, at::Tensor& B) {
- float *pA = A.data_ptr<float>();
- float *pB = B.data_ptr<float>();
+ float* pA = A.data_ptr<float>();
+ float* pB = B.data_ptr<float>();
int size = A.numel();
constexpr int kChunkSize = 16;
TORCH_CHECK(B.numel() == 1);
@@ -146,8 +152,8 @@
}
static void reduce1d_native_vector(at::Tensor& A, at::Tensor& B) {
- float *pA = A.data_ptr<float>();
- float *pB = B.data_ptr<float>();
+ float* pA = A.data_ptr<float>();
+ float* pB = B.data_ptr<float>();
int size = A.numel();
constexpr int kChunkSize = sizeof(__m256) / sizeof(float);
TORCH_CHECK(B.numel() == 1);
@@ -177,12 +183,18 @@
static void reduce1d_native_tiled(at::Tensor& A, at::Tensor& B) {
static constexpr int kTileSize = 4;
- float *pA = A.data_ptr<float>();
- float *pB = B.data_ptr<float>();
+ float* pA = A.data_ptr<float>();
+ float* pB = B.data_ptr<float>();
int size = A.numel();
constexpr int kChunkSize = sizeof(__m256) / sizeof(float);
TORCH_CHECK(B.numel() == 1, "Invalid size: ", B.numel(), " != 1");
- TORCH_CHECK(size % kChunkSize == 0, "Invalid size: ", size, " % ", kChunkSize , " ! = 0");
+ TORCH_CHECK(
+ size % kChunkSize == 0,
+ "Invalid size: ",
+ size,
+ " % ",
+ kChunkSize,
+ " ! = 0");
__m256 t[kTileSize];
for (int j = 0; j < kTileSize; j++) {
t[j] = _mm256_setzero_ps();
@@ -190,9 +202,9 @@
int tile_count = size / kChunkSize / kTileSize;
for (int i = 0; i < tile_count; i++) {
- #pragma unroll
+#pragma unroll
for (int j = 0; j < kTileSize; j++) {
- float *p = pA + (i * kTileSize + j) * kChunkSize;
+ float* p = pA + (i * kTileSize + j) * kChunkSize;
__m256 data = _mm256_loadu_ps(p);
t[j] = _mm256_add_ps(t[j], data);
}
@@ -217,7 +229,6 @@
#endif // USE_AVX2
BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) {
-
int M = A.numel();
te::BufHandle AP("A", {M}, te::kFloat);
@@ -249,7 +260,6 @@
BENCHMARK_REGISTER_F(Reduce1D, TeNaive)->Args({1 << 24});
BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) {
-
int M = A.numel();
te::BufHandle AP("A", {M}, te::kFloat);
@@ -289,7 +299,6 @@
BENCHMARK_REGISTER_F(Reduce1D, TeSplitTail)->Args({1 << 24});
BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) {
-
int M = A.numel();
te::BufHandle AP("A", {M}, te::kFloat);
@@ -329,7 +338,6 @@
BENCHMARK_REGISTER_F(Reduce1D, TeSplitMask)->Args({1 << 24});
BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) {
-
int M = A.numel();
const int kChunkSize = 8;
TORCH_CHECK(M % kChunkSize == 0);
@@ -339,9 +347,7 @@
"reduce_full",
{},
te::Sum(),
- [&](const te::ExprHandle& m) {
- return AP.load(m);
- },
+ [&](const te::ExprHandle& m) { return AP.load(m); },
{{M, "M"}});
te::LoopNest loop({BT});
@@ -424,8 +430,9 @@
void TearDown(benchmark::State& state) override {
TORCH_CHECK(at::allclose(B, ref, std::sqrt(A.numel()) * 1e-5));
- state.counters["BYTES"] = benchmark::Counter(uint64_t(state.iterations()) * (A.nbytes() + B.nbytes()),
- benchmark::Counter::kIsRate);
+ state.counters["BYTES"] = benchmark::Counter(
+ uint64_t(state.iterations()) * (A.nbytes() + B.nbytes()),
+ benchmark::Counter::kIsRate);
}
int M;
@@ -441,9 +448,9 @@
}
}
BENCHMARK_REGISTER_F(Reduce2DCol, Torch)
-->Args({1 << 3, 1 << 21})
-->Args({1 << 6, 1 << 18})
-->Args({1 << 12, 1 << 12});
+ ->Args({1 << 3, 1 << 21})
+ ->Args({1 << 6, 1 << 18})
+ ->Args({1 << 12, 1 << 12});
BENCHMARK_DEFINE_F(Reduce2DCol, OpSchedule)(benchmark::State& state) {
constexpr int kCacheSize = 1 << 12;
@@ -476,15 +483,16 @@
cg.call({A.data_ptr<float>(), B.data_ptr<float>()});
}
}
-BENCHMARK_REGISTER_F(Reduce2DCol, OpSchedule)->Apply(//CustomArgs);
- [](benchmark::internal::Benchmark* b) {
- for (auto sch : {0, 1, 2, 3}) {
- for (auto rows : {3, 6, 12}) {
- auto cols = 24 - rows;
- b->Args({1 << rows, 1 << cols, sch});
- }
- }
- });
+BENCHMARK_REGISTER_F(Reduce2DCol, OpSchedule)
+ ->Apply( // CustomArgs);
+ [](benchmark::internal::Benchmark* b) {
+ for (auto sch : {0, 1, 2, 3}) {
+ for (auto rows : {3, 6, 12}) {
+ auto cols = 24 - rows;
+ b->Args({1 << rows, 1 << cols, sch});
+ }
+ }
+ });
class Reduce2DRow : public benchmark::Fixture {
public:
@@ -500,8 +508,9 @@
void TearDown(benchmark::State& state) override {
TORCH_CHECK(at::allclose(B, ref, std::sqrt(A.numel()) * 1e-4));
- state.counters["BYTES"] = benchmark::Counter(uint64_t(state.iterations()) * (A.nbytes() + B.nbytes()),
- benchmark::Counter::kIsRate);
+ state.counters["BYTES"] = benchmark::Counter(
+ uint64_t(state.iterations()) * (A.nbytes() + B.nbytes()),
+ benchmark::Counter::kIsRate);
}
int M;
@@ -517,10 +526,10 @@
}
}
BENCHMARK_REGISTER_F(Reduce2DRow, Torch)
-->Args({1 << 3, 1 << 21})
-->Args({1 << 6, 1 << 18})
-->Args({1 << 12, 1 << 12})
-->Args({1 << 18, 1 << 6});
+ ->Args({1 << 3, 1 << 21})
+ ->Args({1 << 6, 1 << 18})
+ ->Args({1 << 12, 1 << 12})
+ ->Args({1 << 18, 1 << 6});
BENCHMARK_DEFINE_F(Reduce2DRow, Hand)(benchmark::State& state) {
auto a = A.data_ptr<float>();
@@ -533,7 +542,8 @@
for (int n_outer = 0; n_outer < N; n_outer += Nb) {
for (int m_inner = 0; m_inner < Mb; m_inner++) {
for (int n_inner = 0; n_inner < Nb; n_inner++) {
- bregs[m_inner][n_inner] += a[(m_outer + m_inner) * N + n_outer + n_inner];
+ bregs[m_inner][n_inner] +=
+ a[(m_outer + m_inner) * N + n_outer + n_inner];
}
}
}
@@ -549,13 +559,13 @@
fn();
}
}
-BENCHMARK_REGISTER_F(Reduce2DRow, Hand)
-->Args({1 << 18, 1 << 6});
+BENCHMARK_REGISTER_F(Reduce2DRow, Hand)->Args({1 << 18, 1 << 6});
BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) {
constexpr int kChunkSize = 8;
te::BufHandle a("A", {M, N}, te::kFloat);
- te::Tensor b = te::computeSum({a, te::IntList({1}), false}, {M}, at::kFloat, at::kCPU);
+ te::Tensor b =
+ te::computeSum({a, te::IntList({1}), false}, {M}, at::kFloat, at::kCPU);
te::LoopNest nest({b});
auto sch = state.range(2);
@@ -598,12 +608,13 @@
cg.call({A.data_ptr<float>(), B.data_ptr<float>()});
}
}
-BENCHMARK_REGISTER_F(Reduce2DRow, OpSchedule)->Apply(//CustomArgs);
- [](benchmark::internal::Benchmark* b) {
- for (auto sch : {0, 1, 2, 3}) {
- for (auto rows : {3, 6, 12, 18}) {
- auto cols = 24 - rows;
- b->Args({1 << rows, 1 << cols, sch});
- }
- }
- });
+BENCHMARK_REGISTER_F(Reduce2DRow, OpSchedule)
+ ->Apply( // CustomArgs);
+ [](benchmark::internal::Benchmark* b) {
+ for (auto sch : {0, 1, 2, 3}) {
+ for (auto rows : {3, 6, 12, 18}) {
+ auto cols = 24 - rows;
+ b->Args({1 << rows, 1 << cols, sch});
+ }
+ }
+ });
diff --git a/benchmarks/operator_benchmark/pt_extension/extension.cpp b/benchmarks/operator_benchmark/pt_extension/extension.cpp
index 22a4527..3fda851 100644
--- a/benchmarks/operator_benchmark/pt_extension/extension.cpp
+++ b/benchmarks/operator_benchmark/pt_extension/extension.cpp
@@ -18,8 +18,8 @@
// in a loop and report the execution time. This diff resolves that issue by
// registering this consume op with correct alias information which is DEFAULT.
TORCH_LIBRARY_FRAGMENT(operator_benchmark, m) {
- m.def("_consume", &consume);
- m.def("_consume.list", &consume_list);
+ m.def("_consume", &consume);
+ m.def("_consume.list", &consume_list);
}
PYBIND11_MODULE(benchmark_cpp_extension, m) {
diff --git a/benchmarks/static_runtime/deep_wide_pt.cc b/benchmarks/static_runtime/deep_wide_pt.cc
index b9a6c42..6699b39 100644
--- a/benchmarks/static_runtime/deep_wide_pt.cc
+++ b/benchmarks/static_runtime/deep_wide_pt.cc
@@ -56,7 +56,6 @@
return torch.leaky_relu(x, neg_slope)
)JIT";
-
void import_libs(
std::shared_ptr<at::CompilationUnit> cu,
const std::string& class_name,
@@ -65,9 +64,8 @@
torch::jit::SourceImporter si(
cu,
&tensor_table,
- [&](const std::string& /* unused */) -> std::shared_ptr<torch::jit::Source> {
- return src;
- },
+ [&](const std::string& /* unused */)
+ -> std::shared_ptr<torch::jit::Source> { return src; },
/*version=*/2);
si.loadType(c10::QualifiedName(class_name));
}
diff --git a/benchmarks/static_runtime/deep_wide_pt.h b/benchmarks/static_runtime/deep_wide_pt.h
index 616d497..73a9431 100644
--- a/benchmarks/static_runtime/deep_wide_pt.h
+++ b/benchmarks/static_runtime/deep_wide_pt.h
@@ -103,8 +103,7 @@
}
// Potential optimization: call MKLDNN directly.
- at::cpu::bmm_out(
- ad_emb_packed, prealloc_tensors[3], prealloc_tensors[4]);
+ at::cpu::bmm_out(ad_emb_packed, prealloc_tensors[3], prealloc_tensors[4]);
if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) {
// in unlikely case that the input tensor changed we need to
diff --git a/benchmarks/static_runtime/deep_wide_pt_bench.cc b/benchmarks/static_runtime/deep_wide_pt_bench.cc
index 08bdeaf..df8d2d1 100644
--- a/benchmarks/static_runtime/deep_wide_pt_bench.cc
+++ b/benchmarks/static_runtime/deep_wide_pt_bench.cc
@@ -91,7 +91,8 @@
}
std::shared_ptr<torch::jit::StaticModule> getStaticModule() {
- static auto smod = std::make_shared<torch::jit::StaticModule>(getDeepAndWideSciptModel());
+ static auto smod =
+ std::make_shared<torch::jit::StaticModule>(getDeepAndWideSciptModel());
return smod;
}
@@ -193,17 +194,16 @@
BENCHMARK(BM_deep_wide_static_threaded)->Threads(8);
BENCHMARK(BM_long_static_memory_optimization)
- ->Args({2<<0, 0})
- ->Args({2<<2, 0})
- ->Args({2<<4, 0})
- ->Args({2<<8, 0})
- ->Args({2<<0, 1})
- ->Args({2<<2, 1})
- ->Args({2<<4, 1})
- ->Args({2<<8, 1});
+ ->Args({2 << 0, 0})
+ ->Args({2 << 2, 0})
+ ->Args({2 << 4, 0})
+ ->Args({2 << 8, 0})
+ ->Args({2 << 0, 1})
+ ->Args({2 << 2, 1})
+ ->Args({2 << 4, 1})
+ ->Args({2 << 8, 1});
-int main(int argc, char** argv)
-{
+int main(int argc, char** argv) {
c10::ParseCommandLineFlags(&argc, &argv);
::benchmark::Initialize(&argc, argv);
::benchmark::RunSpecifiedBenchmarks();
diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc
index fbb0087..895972b 100644
--- a/benchmarks/static_runtime/test_static_module.cc
+++ b/benchmarks/static_runtime/test_static_module.cc
@@ -39,12 +39,14 @@
std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
const auto& value_group = sm.value_group();
- std::vector<const Value*> expected_input_aliases{graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
+ std::vector<const Value*> expected_input_aliases{
+ graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
for (auto* value : expected_input_aliases) {
EXPECT_TRUE(value_group.isExternalAlias(value));
}
- std::vector<const Value*> expected_output_aliases{graph.outputs()[0], nodes[2]->output()};
+ std::vector<const Value*> expected_output_aliases{
+ graph.outputs()[0], nodes[2]->output()};
for (auto* value : expected_output_aliases) {
EXPECT_TRUE(value_group.isOutputAlias(value));
}
diff --git a/benchmarks/static_runtime/test_utils.cc b/benchmarks/static_runtime/test_utils.cc
index eb54f95..9caab76 100644
--- a/benchmarks/static_runtime/test_utils.cc
+++ b/benchmarks/static_runtime/test_utils.cc
@@ -175,10 +175,10 @@
} // namespace
std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
- auto graph = std::make_shared<Graph>();
- std::unordered_map<std::string, Value*> vmap;
- parseIR(ir, graph.get(), vmap);
- return graph;
+ auto graph = std::make_shared<Graph>();
+ std::unordered_map<std::string, Value*> vmap;
+ parseIR(ir, graph.get(), vmap);
+ return graph;
}
void testStaticRuntime(
@@ -206,19 +206,18 @@
continue;
}
StaticModuleOptions opts{
- .cleanup_activations = true,
- .enable_out_variant = enable_out_variant,
- .optimize_memory = enable_out_variant,
- .manage_output_tensors = manage_output_tensors
- };
+ .cleanup_activations = true,
+ .enable_out_variant = enable_out_variant,
+ .optimize_memory = enable_out_variant,
+ .manage_output_tensors = manage_output_tensors};
auto smodule = test_context->makeStaticModule(opts);
StaticRuntime runtime(smodule);
auto actual = runtime(args, {});
if (actual.isTensor()) {
EXPECT_GE(smodule.nodes().size(), 2)
- << "If we only have one node, the output of the op we are testing is "
- << "not being managed by the memory planner! A failure here "
- << "can typically be fixed by clone()ing the output of the test script.";
+ << "If we only have one node, the output of the op we are testing is "
+ << "not being managed by the memory planner! A failure here "
+ << "can typically be fixed by clone()ing the output of the test script.";
}
runtime.check_for_memory_leak();
// first run
@@ -239,7 +238,8 @@
runtime.deallocateOutputTensors();
runtime.checkOutputTensorMemoryLeaks();
}
- // Run static runtime again with an input of the shape observed during the profile run.
+ // Run static runtime again with an input of the shape observed during
+ // the profile run.
expect = test_context->getExpected(args);
actual = runtime(args, {});
runtime.check_for_memory_leak();
diff --git a/benchmarks/static_runtime/test_utils.h b/benchmarks/static_runtime/test_utils.h
index 3b5e1a5..1908ada 100644
--- a/benchmarks/static_runtime/test_utils.h
+++ b/benchmarks/static_runtime/test_utils.h
@@ -31,7 +31,9 @@
std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);
-bool hasProcessedNodeWithName(torch::jit::StaticModule& smodule, const char *name);
+bool hasProcessedNodeWithName(
+ torch::jit::StaticModule& smodule,
+ const char* name);
} // namespace test
} // namespace jit
diff --git a/ios/TestApp/TestApp/AppDelegate.m b/ios/TestApp/TestApp/AppDelegate.m
index ed6928a..7438a94 100644
--- a/ios/TestApp/TestApp/AppDelegate.m
+++ b/ios/TestApp/TestApp/AppDelegate.m
@@ -6,38 +6,40 @@
@implementation AppDelegate
-
-- (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
- // Override point for customization after application launch.
- return YES;
+- (BOOL)application:(UIApplication *)application
+ didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
+ // Override point for customization after application launch.
+ return YES;
}
-
- (void)applicationWillResignActive:(UIApplication *)application {
- // Sent when the application is about to move from active to inactive state. This can occur for certain types of temporary interruptions (such as an incoming phone call or SMS message) or when the user quits the application and it begins the transition to the background state.
- // Use this method to pause ongoing tasks, disable timers, and invalidate graphics rendering callbacks. Games should use this method to pause the game.
+ // Sent when the application is about to move from active to inactive state. This can occur for
+ // certain types of temporary interruptions (such as an incoming phone call or SMS message) or
+ // when the user quits the application and it begins the transition to the background state. Use
+ // this method to pause ongoing tasks, disable timers, and invalidate graphics rendering
+ // callbacks. Games should use this method to pause the game.
}
-
- (void)applicationDidEnterBackground:(UIApplication *)application {
- // Use this method to release shared resources, save user data, invalidate timers, and store enough application state information to restore your application to its current state in case it is terminated later.
- // If your application supports background execution, this method is called instead of applicationWillTerminate: when the user quits.
+ // Use this method to release shared resources, save user data, invalidate timers, and store
+ // enough application state information to restore your application to its current state in case
+ // it is terminated later. If your application supports background execution, this method is
+ // called instead of applicationWillTerminate: when the user quits.
}
-
- (void)applicationWillEnterForeground:(UIApplication *)application {
- // Called as part of the transition from the background to the active state; here you can undo many of the changes made on entering the background.
+ // Called as part of the transition from the background to the active state; here you can undo
+ // many of the changes made on entering the background.
}
-
- (void)applicationDidBecomeActive:(UIApplication *)application {
- // Restart any tasks that were paused (or not yet started) while the application was inactive. If the application was previously in the background, optionally refresh the user interface.
+ // Restart any tasks that were paused (or not yet started) while the application was inactive. If
+ // the application was previously in the background, optionally refresh the user interface.
}
-
- (void)applicationWillTerminate:(UIApplication *)application {
- // Called when the application is about to terminate. Save data if appropriate. See also applicationDidEnterBackground:.
+ // Called when the application is about to terminate. Save data if appropriate. See also
+ // applicationDidEnterBackground:.
}
-
@end
diff --git a/ios/TestApp/TestApp/main.m b/ios/TestApp/TestApp/main.m
index 81e84cb..1cfa2c6 100644
--- a/ios/TestApp/TestApp/main.m
+++ b/ios/TestApp/TestApp/main.m
@@ -1,8 +1,8 @@
#import <UIKit/UIKit.h>
#import "AppDelegate.h"
-int main(int argc, char * argv[]) {
- @autoreleasepool {
- return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class]));
- }
+int main(int argc, char* argv[]) {
+ @autoreleasepool {
+ return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class]));
+ }
}
diff --git a/torch/custom_class.h b/torch/custom_class.h
index a80f830..a9270a6 100644
--- a/torch/custom_class.h
+++ b/torch/custom_class.h
@@ -1,6 +1,5 @@
#pragma once
-#include <ATen/core/stack.h>
#include <ATen/core/builtin_function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
@@ -11,8 +10,8 @@
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
-#include <torch/library.h>
#include <torch/custom_class_detail.h>
+#include <torch/library.h>
#include <iostream>
#include <sstream>
@@ -34,8 +33,7 @@
template <typename Func>
decltype(auto) init(Func&& f) {
- using InitTraits =
- c10::guts::infer_function_traits_t<std::decay_t<Func>>;
+ using InitTraits = c10::guts::infer_function_traits_t<std::decay_t<Func>>;
using ParameterTypeList = typename InitTraits::parameter_types;
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
@@ -62,8 +60,9 @@
/// is registered with a C++ lambda expression.
template <class CurClass>
class class_ : public ::torch::detail::class_base {
- static_assert(std::is_base_of<CustomClassHolder, CurClass>::value,
- "torch::class_<T> requires T to inherit from CustomClassHolder");
+ static_assert(
+ std::is_base_of<CustomClassHolder, CurClass>::value,
+ "torch::class_<T> requires T to inherit from CustomClassHolder");
public:
/// This constructor actually registers the class type.
@@ -73,18 +72,27 @@
/// see this class exposed as in Python and TorchScript. For example, if
/// you pass `foo` as the namespace name and `Bar` as the className, the
/// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
- explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "")
- : class_base(namespaceName, className, std::move(doc_string), typeid(c10::intrusive_ptr<CurClass>), typeid(c10::tagged_capsule<CurClass>)) {}
+ explicit class_(
+ const std::string& namespaceName,
+ const std::string& className,
+ std::string doc_string = "")
+ : class_base(
+ namespaceName,
+ className,
+ std::move(doc_string),
+ typeid(c10::intrusive_ptr<CurClass>),
+ typeid(c10::tagged_capsule<CurClass>)) {}
/// def() can be used in conjunction with `torch::init()` to register
/// a constructor for a given C++ class type. For example, passing
- /// `torch::init<int, std::string>()` would register a two-argument constructor
- /// taking an `int` and a `std::string` as argument.
+ /// `torch::init<int, std::string>()` would register a two-argument
+ /// constructor taking an `int` and a `std::string` as argument.
template <typename... Types>
class_& def(
torch::detail::types<void, Types...>,
std::string doc_string = "",
- std::initializer_list<arg> default_args = {}) { // Used in combination with
+ std::initializer_list<arg> default_args =
+ {}) { // Used in combination with
// torch::init<...>()
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
auto classObj = c10::make_intrusive<CurClass>(args...);
@@ -247,11 +255,18 @@
return def_property(name, getter_func);
}
- /// This is an unsafe method registration API added for adding custom JIT backend support via custom
- /// C++ classes. It is not for general purpose use.
- class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema, std::string doc_string = "") {
+ /// This is an unsafe method registration API added for adding custom JIT
+ /// backend support via custom C++ classes. It is not for general purpose use.
+ class_& _def_unboxed(
+ std::string name,
+ std::function<void(jit::Stack&)> func,
+ c10::FunctionSchema schema,
+ std::string doc_string = "") {
auto method = std::make_unique<jit::BuiltinOpFunction>(
- qualClassName + "." + name, std::move(schema), std::move(func), std::move(doc_string));
+ qualClassName + "." + name,
+ std::move(schema),
+ std::move(func),
+ std::move(doc_string));
classTypePtr->addMethod(method.get());
registerCustomClassMethod(std::move(method));
return *this;
@@ -362,7 +377,8 @@
std::string doc_string = "",
std::initializer_list<arg> default_args = {}) {
auto qualMethodName = qualClassName + "." + name;
- auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
+ auto schema =
+ c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
// If default values are provided for function arguments, there must be
// none (no default values) or default values for all function
@@ -372,11 +388,11 @@
// have an actual default value provided.
TORCH_CHECK(
default_args.size() == 0 ||
- default_args.size() == schema.arguments().size() - 1,
+ default_args.size() == schema.arguments().size() - 1,
"Default values must be specified for none or all arguments");
- // If there are default args, copy the argument names and default values to the
- // function schema.
+ // If there are default args, copy the argument names and default values to
+ // the function schema.
if (default_args.size() > 0) {
schema = withNewArguments(schema, default_args);
}
@@ -391,7 +407,10 @@
detail::BoxedProxy<RetType, Func>()(stack, func);
};
auto method = std::make_unique<jit::BuiltinOpFunction>(
- qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string));
+ qualMethodName,
+ std::move(schema),
+ std::move(wrapped_func),
+ std::move(doc_string));
// Register the method here to keep the Method alive.
// ClassTypes do not hold ownership of their methods (normally it
@@ -404,18 +423,20 @@
}
};
-/// make_custom_class() is a convenient way to create an instance of a registered
-/// custom class and wrap it in an IValue, for example when you want to pass the
-/// object to TorchScript. Its syntax is equivalent to APIs like `std::make_shared<>`
-/// or `c10::make_intrusive<>`.
+/// make_custom_class() is a convenient way to create an instance of a
+/// registered custom class and wrap it in an IValue, for example when you want
+/// to pass the object to TorchScript. Its syntax is equivalent to APIs like
+/// `std::make_shared<>` or `c10::make_intrusive<>`.
///
-/// For example, if you have a custom C++ class that can be constructed from an `int`
-/// and `std::string`, you might use this API like so:
+/// For example, if you have a custom C++ class that can be constructed from an
+/// `int` and `std::string`, you might use this API like so:
///
-/// IValue custom_class_iv = torch::make_custom_class<MyClass>(3, "foobarbaz");
+/// IValue custom_class_iv = torch::make_custom_class<MyClass>(3,
+/// "foobarbaz");
template <typename CurClass, typename... CtorArgs>
c10::IValue make_custom_class(CtorArgs&&... args) {
- auto userClassInstance = c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
+ auto userClassInstance =
+ c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
return c10::IValue(std::move(userClassInstance));
}
@@ -439,19 +460,26 @@
// better reflect that these features are not limited only to TorchScript
namespace jit {
-using ::torch::getCustomClass;
-using ::torch::isCustomClass;
-using ::torch::init;
using ::torch::class_;
+using ::torch::getCustomClass;
+using ::torch::init;
+using ::torch::isCustomClass;
} // namespace jit
template <class CurClass>
inline class_<CurClass> Library::class_(const std::string& className) {
- TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
- "class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
- "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
- "(Error occurred at ", file_, ":", line_, ")");
+ TORCH_CHECK(
+ kind_ == DEF || kind_ == FRAGMENT,
+ "class_(\"",
+ className,
+ "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
+ "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
+ "(Error occurred at ",
+ file_,
+ ":",
+ line_,
+ ")");
TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
return torch::class_<CurClass>(*ns_, className);
}
@@ -460,11 +488,18 @@
template <class CurClass>
inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) {
- auto class_name = std::string(className.operator const char *());
- TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
- "class_(\"", class_name, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
- "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
- "(Error occurred at ", file_, ":", line_, ")");
+ auto class_name = std::string(className.operator const char*());
+ TORCH_CHECK(
+ kind_ == DEF || kind_ == FRAGMENT,
+ "class_(\"",
+ class_name,
+ "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
+ "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
+ "(Error occurred at ",
+ file_,
+ ":",
+ line_,
+ ")");
TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
return torch::class_<CurClass>(*ns_, class_name);
}
diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h
index a1c4282..b501053 100644
--- a/torch/custom_class_detail.h
+++ b/torch/custom_class_detail.h
@@ -46,7 +46,8 @@
}
// Explicit constructor.
- explicit arg(std::string name) : name_(std::move(name)), value_(c10::nullopt) {}
+ explicit arg(std::string name)
+ : name_(std::move(name)), value_(c10::nullopt) {}
// Assignment operator. This enables the pybind-like syntax of
// torch::arg("name") = value.
arg& operator=(const c10::IValue& rhs) {
@@ -57,8 +58,8 @@
// The name of the argument. This is copied to the schema; argument
// names cannot be extracted from the C++ declaration.
std::string name_;
- // IValue's default constructor makes it None, which is not distinguishable from
- // an actual, user-provided default value that is None. This boolean
+ // IValue's default constructor makes it None, which is not distinguishable
+ // from an actual, user-provided default value that is None. This boolean
// helps distinguish between the two cases.
c10::optional<c10::IValue> value_;
};
@@ -133,13 +134,15 @@
using IValueArgTypes =
typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
- // TODO We shouldn't use c10::impl stuff directly here. We should use the KernelFunction API instead.
+ // TODO We shouldn't use c10::impl stuff directly here. We should use the
+ // KernelFunction API instead.
return (functor)(c10::impl::ivalue_to_arg<
typename c10::impl::decay_if_not_tensor<
c10::guts::typelist::
element_t<ivalue_arg_indices, IValueArgTypes>>::type,
- AllowDeprecatedTypes>::call(
- torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args))...);
+ AllowDeprecatedTypes>::
+ call(torch::jit::peek(
+ stack, ivalue_arg_indices, num_ivalue_args))...);
}
template <class Functor, bool AllowDeprecatedTypes>
@@ -180,13 +183,17 @@
return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
}
-inline void checkValidIdent(const std::string& str, const char *type) {
+inline void checkValidIdent(const std::string& str, const char* type) {
for (const auto i : c10::irange(str.size())) {
- TORCH_CHECK(validIdent(i, str[i]),
- type,
- " must be a valid Python/C++ identifier."
- " Character '", str[i], "' at index ",
- i, " is illegal.");
+ TORCH_CHECK(
+ validIdent(i, str[i]),
+ type,
+ " must be a valid Python/C++ identifier."
+ " Character '",
+ str[i],
+ "' at index ",
+ i,
+ " is illegal.");
}
}
@@ -227,6 +234,6 @@
namespace jit {
using ::torch::registerCustomClass;
using ::torch::registerCustomClassMethod;
-}
+} // namespace jit
} // namespace torch
diff --git a/torch/lib/libshm/core.cpp b/torch/lib/libshm/core.cpp
index 322e09f..d033806 100644
--- a/torch/lib/libshm/core.cpp
+++ b/torch/lib/libshm/core.cpp
@@ -4,8 +4,8 @@
#include <unordered_map>
#include <libshm/err.h>
-#include <libshm/socket.h>
#include <libshm/libshm.h>
+#include <libshm/socket.h>
std::unordered_map<std::string, ClientSocket> managers;
std::string manager_executable_path;
@@ -47,7 +47,7 @@
constexpr auto MAX_BUFFER_SIZE = 1000;
std::array<char, MAX_BUFFER_SIZE> buffer;
std::string handle;
- while(handle.empty() || handle.back() != '\n') {
+ while (handle.empty() || handle.back() != '\n') {
const auto bytes_read = read(pipe_ends[0], buffer.data(), buffer.size());
SYSCHECK_ERR_RETURN_NEG1(bytes_read);
if (bytes_read == 0) {
@@ -68,11 +68,11 @@
std::string msg("torch_shm_manager at \"");
msg += manager_executable_path;
msg += "\": ";
- msg += handle.substr(7); // remove "ERROR: "
+ msg += handle.substr(7); // remove "ERROR: "
throw std::runtime_error(msg);
}
- ClientSocket manager {handle};
+ ClientSocket manager{handle};
managers.emplace(std::move(handle), std::move(manager));
}
@@ -87,41 +87,49 @@
}
}
-void libshm_init(const char *manager_exec_path) {
+void libshm_init(const char* manager_exec_path) {
manager_executable_path = std::string(manager_exec_path);
}
-THManagedMapAllocatorInit::THManagedMapAllocatorInit(const char* manager_handle, const char* filename)
- : manager_handle_(manager_handle ? manager_handle : "") {
+THManagedMapAllocatorInit::THManagedMapAllocatorInit(
+ const char* manager_handle,
+ const char* filename)
+ : manager_handle_(manager_handle ? manager_handle : "") {
// TODO: unlock GIL when contacting the manager
try {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- ClientSocket *socket;
+ ClientSocket* socket;
if (!manager_handle_.empty()) {
socket = &get_manager_socket(manager_handle_);
} else {
if (managers.size() == 0) {
start_manager();
}
- const auto &manager = managers.begin();
+ const auto& manager = managers.begin();
manager_handle_ = manager->first;
socket = &manager->second;
}
AllocInfo info = get_alloc_info(filename);
socket->register_allocation(info);
- } catch(std::exception &e) {
+ } catch (std::exception& e) {
TORCH_CHECK(false, e.what());
}
}
-THManagedMapAllocator::THManagedMapAllocator(const char *manager_handle, const char *filename, int flags, ptrdiff_t size)
- : THManagedMapAllocatorInit(manager_handle, filename), at::RefcountedMapAllocator(filename, flags, size) {}
+THManagedMapAllocator::THManagedMapAllocator(
+ const char* manager_handle,
+ const char* filename,
+ int flags,
+ ptrdiff_t size)
+ : THManagedMapAllocatorInit(manager_handle, filename),
+ at::RefcountedMapAllocator(filename, flags, size) {}
void THManagedMapAllocator::close() {
- if (closed_) return;
+ if (closed_)
+ return;
AllocInfo info = get_alloc_info(filename());
info.free = true;
- ClientSocket &socket = get_manager_socket(manager_handle_);
+ ClientSocket& socket = get_manager_socket(manager_handle_);
at::RefcountedMapAllocator::close();
socket.register_deallocation(info);
}
@@ -130,11 +138,21 @@
delete static_cast<THManagedMapAllocator*>(ptr);
}
-at::DataPtr THManagedMapAllocator::makeDataPtr(const char* manager_handle, const char* filename, int flags, ptrdiff_t size) {
- auto* context = new THManagedMapAllocator(manager_handle, filename, flags, size);
- return {context->data(), context, &deleteTHManagedMapAllocator, at::DeviceType::CPU};
+at::DataPtr THManagedMapAllocator::makeDataPtr(
+ const char* manager_handle,
+ const char* filename,
+ int flags,
+ ptrdiff_t size) {
+ auto* context =
+ new THManagedMapAllocator(manager_handle, filename, flags, size);
+ return {
+ context->data(),
+ context,
+ &deleteTHManagedMapAllocator,
+ at::DeviceType::CPU};
}
-THManagedMapAllocator* THManagedMapAllocator::fromDataPtr(const at::DataPtr& dptr) {
+THManagedMapAllocator* THManagedMapAllocator::fromDataPtr(
+ const at::DataPtr& dptr) {
return dptr.cast_context<THManagedMapAllocator>(&deleteTHManagedMapAllocator);
}
diff --git a/torch/lib/libshm/err.h b/torch/lib/libshm/err.h
index 5432510..e1e6aa4 100644
--- a/torch/lib/libshm/err.h
+++ b/torch/lib/libshm/err.h
@@ -1,7 +1,7 @@
#pragma once
-#include <system_error>
#include <cerrno>
+#include <system_error>
// `errno` is only meaningful when it fails. E.g., a successful `fork()` sets
// `errno` to `EINVAL` in child process on some macos
@@ -11,15 +11,15 @@
// All functions used in `libshm` (so far) indicate error by returning `-1`. If
// you want to use a function with a different error reporting mechanism, you
// need to port `SYSCHECK` from `torch/lib/c10d/Utils.hpp`.
-#define SYSCHECK_ERR_RETURN_NEG1(expr) \
-while (true) { \
- if ((expr) == -1) { \
- if (errno == EINTR) { \
- continue; \
- } else { \
- throw std::system_error(errno, std::system_category()); \
- } \
- } else { \
- break; \
- } \
-}
+#define SYSCHECK_ERR_RETURN_NEG1(expr) \
+ while (true) { \
+ if ((expr) == -1) { \
+ if (errno == EINTR) { \
+ continue; \
+ } else { \
+ throw std::system_error(errno, std::system_category()); \
+ } \
+ } else { \
+ break; \
+ } \
+ }
diff --git a/torch/lib/libshm/libshm.h b/torch/lib/libshm/libshm.h
index 7bc612f..b289f9a 100644
--- a/torch/lib/libshm/libshm.h
+++ b/torch/lib/libshm/libshm.h
@@ -4,11 +4,11 @@
#ifdef __cplusplus
-void libshm_init(const char *manager_exec_path);
+void libshm_init(const char* manager_exec_path);
// Superclass to run a constructor before at::RefcountedMapAllocator
class THManagedMapAllocatorInit {
-protected:
+ protected:
THManagedMapAllocatorInit(const char* manager_handle, const char* filename);
std::string manager_handle_;
};
@@ -16,18 +16,31 @@
// Like a at::RefcountedMapAllocator, but it also makes use of an external
// shared memory manager process to ensure that shared memory regions actually
// get freed in the end (even if processes lose the memory).
-class THManagedMapAllocator : private THManagedMapAllocatorInit, public at::RefcountedMapAllocator {
-public:
- THManagedMapAllocator(const char* manager_handle, const char* filename, int flags, ptrdiff_t size);
+class THManagedMapAllocator : private THManagedMapAllocatorInit,
+ public at::RefcountedMapAllocator {
+ public:
+ THManagedMapAllocator(
+ const char* manager_handle,
+ const char* filename,
+ int flags,
+ ptrdiff_t size);
void close() override;
- ~THManagedMapAllocator() { close(); }
+ ~THManagedMapAllocator() {
+ close();
+ }
- static at::DataPtr makeDataPtr(const char* manager_handle, const char* filename, int flags, ptrdiff_t size);
+ static at::DataPtr makeDataPtr(
+ const char* manager_handle,
+ const char* filename,
+ int flags,
+ ptrdiff_t size);
static THManagedMapAllocator* fromDataPtr(const at::DataPtr&);
- const char* manager_handle() const { return manager_handle_.c_str(); }
+ const char* manager_handle() const {
+ return manager_handle_.c_str();
+ }
};
#endif
diff --git a/torch/lib/libshm/manager.cpp b/torch/lib/libshm/manager.cpp
index 3622920..8bbb5b1 100644
--- a/torch/lib/libshm/manager.cpp
+++ b/torch/lib/libshm/manager.cpp
@@ -1,11 +1,11 @@
-#include <algorithm>
-#include <cerrno>
#include <fcntl.h>
-#include <memory>
#include <poll.h>
-#include <set>
#include <sys/mman.h>
#include <unistd.h>
+#include <algorithm>
+#include <cerrno>
+#include <memory>
+#include <set>
#include <unordered_map>
#include <vector>
@@ -26,19 +26,17 @@
#endif
struct ClientSession {
- ClientSession(ManagerSocket s): socket(std::move(s)), pid(0) {}
+ ClientSession(ManagerSocket s) : socket(std::move(s)), pid(0) {}
ManagerSocket socket;
pid_t pid;
};
-
std::vector<struct pollfd> pollfds;
std::unordered_map<int, ClientSession> client_sessions;
// TODO: check if objects have been freed from time to time
std::set<std::string> used_objects;
-
void register_fd(int fd) {
struct pollfd pfd = {0};
pfd.fd = fd;
@@ -46,17 +44,17 @@
pollfds.push_back(pfd);
}
-
void unregister_fd(int fd) {
pollfds.erase(
- std::remove_if(pollfds.begin(), pollfds.end(),
- [fd](const struct pollfd &pfd) { return pfd.fd == fd; }),
- pollfds.end());
+ std::remove_if(
+ pollfds.begin(),
+ pollfds.end(),
+ [fd](const struct pollfd& pfd) { return pfd.fd == fd; }),
+ pollfds.end());
client_sessions.erase(fd);
}
-
-void print_init_message(const char *message) {
+void print_init_message(const char* message) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t unused;
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
@@ -65,7 +63,7 @@
unused = write(1, "\n", 1);
}
-bool object_exists(const char *name) {
+bool object_exists(const char* name) {
int fd = shm_open(name, O_RDONLY, 0);
if (fd >= 0) {
close(fd);
@@ -75,7 +73,7 @@
}
}
-void free_used_object(const std::string &name) {
+void free_used_object(const std::string& name) {
if (!object_exists(name.c_str())) {
DEBUG("object %s appears to have been freed", name.c_str());
used_objects.erase(name);
@@ -85,14 +83,13 @@
}
// NOLINTNEXTLINE(bugprone-exception-escape)
-int main(int argc, char *argv[]) {
- setsid(); // Daemonize the process
+int main(int argc, char* argv[]) {
+ setsid(); // Daemonize the process
std::unique_ptr<ManagerServerSocket> srv_socket;
c10::optional<c10::TempDir> tempdir;
try {
- tempdir =
- c10::try_make_tempdir(/*name_prefix=*/"torch-shm-dir-");
+ tempdir = c10::try_make_tempdir(/*name_prefix=*/"torch-shm-dir-");
if (!tempdir.has_value()) {
throw std::runtime_error(
"could not generate a random directory for manager socket");
@@ -122,17 +119,18 @@
int nevents;
if (client_sessions.size() == 0)
timeout = SHUTDOWN_TIMEOUT;
- SYSCHECK_ERR_RETURN_NEG1(nevents = poll(pollfds.data(), pollfds.size(), timeout));
+ SYSCHECK_ERR_RETURN_NEG1(
+ nevents = poll(pollfds.data(), pollfds.size(), timeout));
timeout = -1;
if (nevents == 0 && client_sessions.size() == 0)
break;
- for (auto &pfd: pollfds) {
+ for (auto& pfd : pollfds) {
if (pfd.revents & (POLLERR | POLLHUP)) {
// some process died
DEBUG("detaching process");
- auto &session = client_sessions.at(pfd.fd);
- (void) session;
+ auto& session = client_sessions.at(pfd.fd);
+ (void)session;
DEBUG("%d has died", session.pid);
to_remove.push_back(pfd.fd);
} else if (pfd.revents & POLLIN) {
@@ -146,10 +144,14 @@
} else {
// someone wants to register a segment
DEBUG("got alloc info");
- auto &session = client_sessions.at(pfd.fd);
+ auto& session = client_sessions.at(pfd.fd);
AllocInfo info = session.socket.receive();
session.pid = info.pid;
- DEBUG("got alloc info: %d %d %s", (int)info.free, info.pid, info.filename);
+ DEBUG(
+ "got alloc info: %d %d %s",
+ (int)info.free,
+ info.pid,
+ info.filename);
if (info.free) {
free_used_object(info.filename);
} else {
@@ -161,22 +163,22 @@
}
}
- for (int fd: to_add)
+ for (int fd : to_add)
register_fd(fd);
to_add.clear();
- for (int fd: to_remove)
+ for (int fd : to_remove)
unregister_fd(fd);
to_remove.clear();
}
- for (auto &obj_name: used_objects) {
+ for (auto& obj_name : used_objects) {
DEBUG("freeing %s", obj_name.c_str());
shm_unlink(obj_name.c_str());
}
// Clean up file descriptors
- for (auto &pfd: pollfds) {
+ for (auto& pfd : pollfds) {
unregister_fd(pfd.fd);
}
// Clean up manager.sock
diff --git a/torch/lib/libshm/socket.h b/torch/lib/libshm/socket.h
index ef31e80..ee36f46 100644
--- a/torch/lib/libshm/socket.h
+++ b/torch/lib/libshm/socket.h
@@ -1,31 +1,33 @@
#pragma once
-#include <sys/types.h>
+#include <poll.h>
#include <sys/socket.h>
#include <sys/stat.h>
+#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
-#include <poll.h>
-#include <cstdio>
-#include <string>
-#include <sstream>
-#include <iostream>
-#include <cstring>
#include <cstddef>
+#include <cstdio>
+#include <cstring>
+#include <iostream>
+#include <sstream>
+#include <string>
-#include <libshm/err.h>
#include <libshm/alloc_info.h>
+#include <libshm/err.h>
class Socket {
-public:
+ public:
int socket_fd;
-protected:
+ protected:
Socket() {
SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
}
Socket(const Socket& other) = delete;
- Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { other.socket_fd = -1; };
+ Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
+ other.socket_fd = -1;
+ };
explicit Socket(int fd) : socket_fd(fd) {}
virtual ~Socket() {
@@ -33,7 +35,7 @@
close(socket_fd);
}
- struct sockaddr_un prepare_address(const char *path) {
+ struct sockaddr_un prepare_address(const char* path) {
struct sockaddr_un address;
address.sun_family = AF_UNIX;
strcpy(address.sun_path, path);
@@ -45,8 +47,8 @@
return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
}
- void recv(void *_buffer, size_t num_bytes) {
- char *buffer = (char*)_buffer;
+ void recv(void* _buffer, size_t num_bytes) {
+ char* buffer = (char*)_buffer;
size_t bytes_received = 0;
ssize_t step_received;
struct pollfd pfd = {0};
@@ -55,36 +57,39 @@
while (bytes_received < num_bytes) {
SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
if (pfd.revents & POLLIN) {
- SYSCHECK_ERR_RETURN_NEG1(step_received = ::read(socket_fd, buffer, num_bytes - bytes_received));
+ SYSCHECK_ERR_RETURN_NEG1(
+ step_received =
+ ::read(socket_fd, buffer, num_bytes - bytes_received));
if (step_received == 0)
throw std::runtime_error("Other end has closed the connection");
bytes_received += step_received;
buffer += step_received;
} else if (pfd.revents & (POLLERR | POLLHUP)) {
- throw std::runtime_error("An error occurred while waiting for the data");
+ throw std::runtime_error(
+ "An error occurred while waiting for the data");
} else {
- throw std::runtime_error("Shared memory manager connection has timed out");
+ throw std::runtime_error(
+ "Shared memory manager connection has timed out");
}
}
}
- void send(const void *_buffer, size_t num_bytes) {
- const char *buffer = (const char*)_buffer;
+ void send(const void* _buffer, size_t num_bytes) {
+ const char* buffer = (const char*)_buffer;
size_t bytes_sent = 0;
ssize_t step_sent;
while (bytes_sent < num_bytes) {
- SYSCHECK_ERR_RETURN_NEG1(step_sent = ::write(socket_fd, buffer, num_bytes));
+ SYSCHECK_ERR_RETURN_NEG1(
+ step_sent = ::write(socket_fd, buffer, num_bytes));
bytes_sent += step_sent;
buffer += step_sent;
}
}
-
-
};
-class ManagerSocket: public Socket {
-public:
- explicit ManagerSocket(int fd): Socket(fd) {}
+class ManagerSocket : public Socket {
+ public:
+ explicit ManagerSocket(int fd) : Socket(fd) {}
AllocInfo receive() {
AllocInfo info;
@@ -95,20 +100,19 @@
void confirm() {
send("OK", 2);
}
-
};
-
-class ManagerServerSocket: public Socket {
-public:
- explicit ManagerServerSocket(const std::string &path) {
+class ManagerServerSocket : public Socket {
+ public:
+ explicit ManagerServerSocket(const std::string& path) {
socket_path = path;
try {
struct sockaddr_un address = prepare_address(path.c_str());
size_t len = address_length(address);
- SYSCHECK_ERR_RETURN_NEG1(bind(socket_fd, (struct sockaddr *)&address, len));
+ SYSCHECK_ERR_RETURN_NEG1(
+ bind(socket_fd, (struct sockaddr*)&address, len));
SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
- } catch(std::exception &e) {
+ } catch (std::exception& e) {
SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
throw;
}
@@ -128,36 +132,38 @@
int client_fd;
struct sockaddr_un addr;
socklen_t addr_len = sizeof(addr);
- SYSCHECK_ERR_RETURN_NEG1(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
+ SYSCHECK_ERR_RETURN_NEG1(
+ client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
return ManagerSocket(client_fd);
}
std::string socket_path;
};
-class ClientSocket: public Socket {
-public:
- explicit ClientSocket(const std::string &path) {
+class ClientSocket : public Socket {
+ public:
+ explicit ClientSocket(const std::string& path) {
try {
struct sockaddr_un address = prepare_address(path.c_str());
size_t len = address_length(address);
- SYSCHECK_ERR_RETURN_NEG1(connect(socket_fd, (struct sockaddr *)&address, len));
- } catch(std::exception &e) {
+ SYSCHECK_ERR_RETURN_NEG1(
+ connect(socket_fd, (struct sockaddr*)&address, len));
+ } catch (std::exception& e) {
SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
throw;
}
}
- void register_allocation(AllocInfo &info) {
+ void register_allocation(AllocInfo& info) {
char buffer[3] = {0, 0, 0};
send(&info, sizeof(info));
recv(buffer, 2);
if (strcmp(buffer, "OK") != 0)
- throw std::runtime_error("Shared memory manager didn't respond with an OK");
+ throw std::runtime_error(
+ "Shared memory manager didn't respond with an OK");
}
- void register_deallocation(AllocInfo &info) {
+ void register_deallocation(AllocInfo& info) {
send(&info, sizeof(info));
}
-
};
diff --git a/torch/lib/libshm_windows/core.cpp b/torch/lib/libshm_windows/core.cpp
index d359bb7..4037d57 100644
--- a/torch/lib/libshm_windows/core.cpp
+++ b/torch/lib/libshm_windows/core.cpp
@@ -4,19 +4,23 @@
#include <libshm_windows/libshm.h>
-
-void libshm_init(const char *manager_exec_path) {
-}
+void libshm_init(const char* manager_exec_path) {}
static void deleteTHManagedMapAllocator(void* ptr) {
delete static_cast<THManagedMapAllocator*>(ptr);
}
-at::DataPtr THManagedMapAllocator::makeDataPtr(const char* manager_handle, const char* filename, int flags, ptrdiff_t size) {
- auto* context = new THManagedMapAllocator(manager_handle, filename, flags, size);
+at::DataPtr THManagedMapAllocator::makeDataPtr(
+ const char* manager_handle,
+ const char* filename,
+ int flags,
+ ptrdiff_t size) {
+ auto* context =
+ new THManagedMapAllocator(manager_handle, filename, flags, size);
return {context->data(), context, &deleteTHManagedMapAllocator, at::kCPU};
}
-THManagedMapAllocator* THManagedMapAllocator::fromDataPtr(const at::DataPtr& dptr) {
+THManagedMapAllocator* THManagedMapAllocator::fromDataPtr(
+ const at::DataPtr& dptr) {
return dptr.cast_context<THManagedMapAllocator>(&deleteTHManagedMapAllocator);
}
diff --git a/torch/lib/libshm_windows/libshm.h b/torch/lib/libshm_windows/libshm.h
index be22f39..5629e83 100644
--- a/torch/lib/libshm_windows/libshm.h
+++ b/torch/lib/libshm_windows/libshm.h
@@ -5,22 +5,32 @@
#ifdef __cplusplus
#ifdef SHM_EXPORTS
-# define SHM_API __declspec(dllexport)
+#define SHM_API __declspec(dllexport)
#else
-# define SHM_API __declspec(dllimport)
+#define SHM_API __declspec(dllimport)
#endif
-SHM_API void libshm_init(const char *manager_exec_path);
+SHM_API void libshm_init(const char* manager_exec_path);
class SHM_API THManagedMapAllocator : public at::RefcountedMapAllocator {
-public:
- THManagedMapAllocator(const char* manager_handle, const char* filename, int flags, ptrdiff_t size)
- : at::RefcountedMapAllocator(filename, flags, size) {}
+ public:
+ THManagedMapAllocator(
+ const char* manager_handle,
+ const char* filename,
+ int flags,
+ ptrdiff_t size)
+ : at::RefcountedMapAllocator(filename, flags, size) {}
- static at::DataPtr makeDataPtr(const char* manager_handle, const char* filename, int flags, ptrdiff_t size);
+ static at::DataPtr makeDataPtr(
+ const char* manager_handle,
+ const char* filename,
+ int flags,
+ ptrdiff_t size);
static THManagedMapAllocator* fromDataPtr(const at::DataPtr&);
- const char* manager_handle() const { return "no_manager"; }
+ const char* manager_handle() const {
+ return "no_manager";
+ }
};
#endif
diff --git a/torch/library.h b/torch/library.h
index 9fe93e6..51a2333 100644
--- a/torch/library.h
+++ b/torch/library.h
@@ -7,11 +7,11 @@
/// API can be used in a few ways:
///
/// * You can define new custom operators and classes with TORCH_LIBRARY(),
-/// making them available for use in both eager Python as well as in TorchScript.
-/// This API is modeled off of pybind11's `PYBIND11_MODULE` macro, as
-/// the provided functionality is similar (pybind11 lets you bind C++
-/// to Python only; `torch/library.h` lets you bind C++ simultaneously
-/// to Python and TorchScript).
+/// making them available for use in both eager Python as well as in
+/// TorchScript. This API is modeled off of pybind11's `PYBIND11_MODULE`
+/// macro, as the provided functionality is similar (pybind11 lets you bind
+/// C++ to Python only; `torch/library.h` lets you bind C++ simultaneously to
+/// Python and TorchScript).
///
/// * You can override existing operators with TORCH_LIBRARY_IMPL(),
/// providing a new implementation for these operators for a custom
@@ -58,9 +58,9 @@
/// }
/// ```
-#include <c10/core/DispatchKey.h>
-#include <ATen/core/op_registration/op_allowlist.h>
#include <ATen/core/op_registration/infer_schema.h>
+#include <ATen/core/op_registration/op_allowlist.h>
+#include <c10/core/DispatchKey.h>
#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#endif
@@ -101,64 +101,100 @@
// TODO: This is morally the same thing as KernelRegistrationConfig, but it's
// opaque to the user.
-public:
+ public:
/// This overload accepts function pointers, e.g., `CppFunction(&add_impl)`
template <typename Func>
- explicit CppFunction(Func* f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)
- : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f))
- , cpp_signature_(c10::impl::CppSignature::make<Func>())
- , schema_(c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>())
- , debug_()
- {}
+ explicit CppFunction(
+ Func* f,
+ std::enable_if_t<
+ c10::guts::is_function_type<Func>::value,
+ std::nullptr_t> = nullptr)
+ : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
+ cpp_signature_(c10::impl::CppSignature::make<Func>()),
+ schema_(
+ c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
+ debug_() {}
- /// This overload accepts compile time function pointers, e.g., `CppFunction(TORCH_FN(add_impl))`
+ /// This overload accepts compile time function pointers, e.g.,
+ /// `CppFunction(TORCH_FN(add_impl))`
template <typename FuncPtr>
- explicit CppFunction(FuncPtr f, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr)
- : func_(c10::KernelFunction::makeFromUnboxedFunction(f))
- , cpp_signature_(c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
- , schema_(c10::detail::inferFunctionSchemaFromFunctor<typename FuncPtr::FuncType>())
- , debug_()
- {}
+ explicit CppFunction(
+ FuncPtr f,
+ std::enable_if_t<
+ c10::is_compile_time_function_pointer<FuncPtr>::value,
+ std::nullptr_t> = nullptr)
+ : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
+ cpp_signature_(
+ c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
+ schema_(c10::detail::inferFunctionSchemaFromFunctor<
+ typename FuncPtr::FuncType>()),
+ debug_() {}
- /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) { ... })`
+ /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
+ /// ... })`
template <typename Lambda>
- explicit CppFunction(Lambda&& f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)
- : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(f)))
- , cpp_signature_(c10::impl::CppSignature::make<Lambda>())
- , schema_(c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Lambda>>())
- , debug_()
- {}
+ explicit CppFunction(
+ Lambda&& f,
+ std::enable_if_t<
+ c10::guts::is_functor<std::decay_t<Lambda>>::value,
+ std::nullptr_t> = nullptr)
+ : func_(c10::KernelFunction::makeFromUnboxedLambda(
+ std::forward<Lambda>(f))),
+ cpp_signature_(c10::impl::CppSignature::make<Lambda>()),
+ schema_(c10::detail::inferFunctionSchemaFromFunctor<
+ std::decay_t<Lambda>>()),
+ debug_() {}
#if defined C10_MOBILE
- /// This overload accepts function pointers, e.g., `CppFunction(&add_impl, NoInferSchemaTag())`
+ /// This overload accepts function pointers, e.g., `CppFunction(&add_impl,
+ /// NoInferSchemaTag())`
template <typename Func>
- explicit CppFunction(Func* f, NoInferSchemaTag, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)
- : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f))
- , cpp_signature_(c10::impl::CppSignature::make<Func>())
- // TODO: Don't go through WrapRuntimeKernelFunctor
- , schema_(nullptr)
- , debug_()
- {}
+ explicit CppFunction(
+ Func* f,
+ NoInferSchemaTag,
+ std::enable_if_t<
+ c10::guts::is_function_type<Func>::value,
+ std::nullptr_t> = nullptr)
+ : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
+ cpp_signature_(c10::impl::CppSignature::make<Func>())
+ // TODO: Don't go through WrapRuntimeKernelFunctor
+ ,
+ schema_(nullptr),
+ debug_() {}
- /// This overload accepts compile time function pointers, e.g., `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
+ /// This overload accepts compile time function pointers, e.g.,
+ /// `CppFunction(TORCH_FN(add_impl), NoInferSchemaTag())`
template <typename FuncPtr>
- explicit CppFunction(FuncPtr f, NoInferSchemaTag, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr)
- : func_(c10::KernelFunction::makeFromUnboxedFunction(f))
- , cpp_signature_(c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
- // TODO: Don't go through WrapRuntimeKernelFunctor
- , schema_(nullptr)
- , debug_()
- {}
+ explicit CppFunction(
+ FuncPtr f,
+ NoInferSchemaTag,
+ std::enable_if_t<
+ c10::is_compile_time_function_pointer<FuncPtr>::value,
+ std::nullptr_t> = nullptr)
+ : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
+ cpp_signature_(
+ c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
+ // TODO: Don't go through WrapRuntimeKernelFunctor
+ ,
+ schema_(nullptr),
+ debug_() {}
- /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) { ... }. NoInferSchemaTag())`
+ /// This overload accepts lambdas, e.g., `CppFunction([](const Tensor& self) {
+ /// ... }. NoInferSchemaTag())`
template <typename Lambda>
- explicit CppFunction(Lambda&& f, NoInferSchemaTag, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)
- : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(f)))
- , cpp_signature_(c10::impl::CppSignature::make<Lambda>())
- // TODO: Don't go through WrapRuntimeKernelFunctor
- , schema_(nullptr)
- , debug_()
- {}
+ explicit CppFunction(
+ Lambda&& f,
+ NoInferSchemaTag,
+ std::enable_if_t<
+ c10::guts::is_functor<std::decay_t<Lambda>>::value,
+ std::nullptr_t> = nullptr)
+ : func_(c10::KernelFunction::makeFromUnboxedLambda(
+ std::forward<Lambda>(f))),
+ cpp_signature_(c10::impl::CppSignature::make<Lambda>())
+ // TODO: Don't go through WrapRuntimeKernelFunctor
+ ,
+ schema_(nullptr),
+ debug_() {}
#endif
/// This creates a fallthrough function. Fallthrough functions
@@ -168,10 +204,9 @@
static CppFunction makeFallthrough() {
// TODO: more user friendly API
return CppFunction(
- c10::KernelFunction::makeFallthrough(),
- /* cpp_signature */ c10::nullopt, // not known for fallthroughs
- /* schema */ nullptr
- );
+ c10::KernelFunction::makeFallthrough(),
+ /* cpp_signature */ c10::nullopt, // not known for fallthroughs
+ /* schema */ nullptr);
}
/// \private
@@ -180,10 +215,9 @@
/// are not supported when called.
static CppFunction makeNamedNotSupported() {
return CppFunction(
- c10::KernelFunction::makeNamedNotSupported(),
- /* cpp_signature */ c10::nullopt, // not known for fallthroughs
- /* schema */ nullptr
- );
+ c10::KernelFunction::makeNamedNotSupported(),
+ /* cpp_signature */ c10::nullopt, // not known for fallthroughs
+ /* schema */ nullptr);
}
/// Create a function from a boxed kernel function with signature
@@ -192,26 +226,25 @@
/// in the native C++ calling convention. Boxed functions are
/// typically only used to register backend fallbacks via
/// torch::Library::fallback().
- template<c10::KernelFunction::BoxedKernelFunction* func>
+ template <c10::KernelFunction::BoxedKernelFunction* func>
static CppFunction makeFromBoxedFunction() {
// TODO: more user friendly API
return CppFunction(
- c10::KernelFunction::makeFromBoxedFunction<func>(),
- /* cpp_signature */ c10::nullopt, // not known for boxed functions
- /* schema */ nullptr
- );
+ c10::KernelFunction::makeFromBoxedFunction<func>(),
+ /* cpp_signature */ c10::nullopt, // not known for boxed functions
+ /* schema */ nullptr);
}
- // Variant that takes in a boxed kernel function with a plumbed DispatchKeySet.
- // See Note [Plumbing Keys Through The Dispatcher] for details.
- template<c10::KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
+ // Variant that takes in a boxed kernel function with a plumbed
+ // DispatchKeySet. See Note [Plumbing Keys Through The Dispatcher] for
+ // details.
+ template <c10::KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
static CppFunction makeFromBoxedFunction() {
// TODO: more user friendly API
return CppFunction(
- c10::KernelFunction::makeFromBoxedFunction<func>(),
- /* cpp_signature */ c10::nullopt, // not known for boxed functions
- /* schema */ nullptr
- );
+ c10::KernelFunction::makeFromBoxedFunction<func>(),
+ /* cpp_signature */ c10::nullopt, // not known for boxed functions
+ /* schema */ nullptr);
}
/// Create a function from a boxed kernel functor which defines
@@ -222,18 +255,22 @@
/// is managed by the functor; this is useful if you're writing an
/// adapter to some other implementation, e.g., a Python callable, which
/// is dynamically associated with the registered kernel.
- template<class KernelFunctor>
- static CppFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
+ template <class KernelFunctor>
+ static CppFunction makeFromBoxedFunctor(
+ std::unique_ptr<KernelFunctor> kernelFunctor) {
return CppFunction(
- c10::KernelFunction::makeFromBoxedFunctor(std::move(kernelFunctor)),
- /* cpp_signature */ c10::nullopt, // not known for boxed functions
- /* schema */ nullptr
- );
+ c10::KernelFunction::makeFromBoxedFunctor(std::move(kernelFunctor)),
+ /* cpp_signature */ c10::nullopt, // not known for boxed functions
+ /* schema */ nullptr);
}
/// Create a function from an unboxed kernel function.
/// This is typically used to register common operators.
- template<typename FuncPtr, std::enable_if_t<c10::guts::is_function_type<FuncPtr>::value, std::nullptr_t> = nullptr>
+ template <
+ typename FuncPtr,
+ std::enable_if_t<
+ c10::guts::is_function_type<FuncPtr>::value,
+ std::nullptr_t> = nullptr>
static CppFunction makeFromUnboxedFunction(FuncPtr* f) {
return CppFunction(f);
}
@@ -242,7 +279,11 @@
/// This is typically used to register common operators.
/// Compile time function pointers can be used to allow the compiler
/// to optimize (e.g. inline) calls to it.
- template<typename FuncPtr, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr>
+ template <
+ typename FuncPtr,
+ std::enable_if_t<
+ c10::is_compile_time_function_pointer<FuncPtr>::value,
+ std::nullptr_t> = nullptr>
static CppFunction makeFromUnboxedFunction(FuncPtr f) {
return CppFunction(f);
}
@@ -252,7 +293,7 @@
return std::move(*this);
}
-private:
+ private:
c10::optional<c10::DispatchKey> dispatch_key_;
c10::KernelFunction func_;
c10::optional<c10::impl::CppSignature> cpp_signature_;
@@ -268,7 +309,10 @@
// want users to use)
friend class Library;
- CppFunction(c10::KernelFunction func, c10::optional<c10::impl::CppSignature> cpp_signature, std::unique_ptr<c10::FunctionSchema> schema);
+ CppFunction(
+ c10::KernelFunction func,
+ c10::optional<c10::impl::CppSignature> cpp_signature,
+ std::unique_ptr<c10::FunctionSchema> schema);
};
/// \defgroup torch-dispatch-overloads torch::dispatch overloads
@@ -300,7 +344,7 @@
/// \ingroup torch-dispatch-overloads
template <typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
- auto deviceTypeToDispatchKey = [](c10::DeviceType t){
+ auto deviceTypeToDispatchKey = [](c10::DeviceType t) {
switch (t) {
// This list is synchronized with the k-constants in c10/core/DeviceType.h
case c10::DeviceType::CPU:
@@ -322,9 +366,12 @@
case c10::DeviceType::HPU:
return c10::DispatchKey::HPU;
default:
- TORCH_CHECK(false,
- "Device type ", t, " cannot be overloaded at dispatch time, "
- "please file a bug report explaining what you were trying to do.");
+ TORCH_CHECK(
+ false,
+ "Device type ",
+ t,
+ " cannot be overloaded at dispatch time, "
+ "please file a bug report explaining what you were trying to do.");
}
};
return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
@@ -341,7 +388,8 @@
/// // Default alias analysis (FROM_SCHEMA)
/// m.def("def3(Tensor self) -> Tensor");
/// // Pure function alias analysis
-/// m.def(torch::schema("def3(Tensor self) -> Tensor", c10::AliasAnalysisKind::PURE_FUNCTION));
+/// m.def(torch::schema("def3(Tensor self) -> Tensor",
+/// c10::AliasAnalysisKind::PURE_FUNCTION));
/// ```
///
/// \ingroup torch-schema-overloads
@@ -364,25 +412,30 @@
/// rvalues.
///
/// \ingroup torch-schema-overloads
-inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) { return std::move(s); }
+inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) {
+ return std::move(s);
+}
namespace detail {
- inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::FunctionSchema&& s) {
- return c10::make_right<c10::OperatorName, c10::FunctionSchema>(std::move(s));
+inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
+ c10::FunctionSchema&& s) {
+ return c10::make_right<c10::OperatorName, c10::FunctionSchema>(std::move(s));
+}
+inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
+ c10::OperatorName&& n) {
+ return c10::make_left<c10::OperatorName, c10::FunctionSchema>(std::move(n));
+}
+inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(
+ const char* str) {
+ auto s = torch::jit::parseSchemaOrName(str);
+ if (s.is_right()) {
+ s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
}
- inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::OperatorName&& n) {
- return c10::make_left<c10::OperatorName, c10::FunctionSchema>(std::move(n));
- }
- inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(const char* str) {
- auto s = torch::jit::parseSchemaOrName(str);
- if (s.is_right()) {
- s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
- }
- return s;
- }
+ return s;
+}
- class TorchLibraryInit;
+class TorchLibraryInit;
} // namespace detail
@@ -401,40 +454,50 @@
// Instead of doing this, we have a different mechanism centered around the
// concept of a SelectiveStr. A selective name is like a const char* string,
// except it also carries at compile time a boolean saying whether or not a
-// registration should actually happen or not. We then have extra overloads which
-// bypass registration entirely if a selective name is disabled. We do a
+// registration should actually happen or not. We then have extra overloads
+// which bypass registration entirely if a selective name is disabled. We do a
// constexpr test to see if a operator should be enabled or not; this is
// currently implemented in ATen/core/op_registration/op_allowlist.h
namespace detail {
- // dummy class for non selected custom torchbind classes
- class ClassNotSelected {
- public:
- ClassNotSelected& def_pickle(...){ return *this;}
- ClassNotSelected& def(...){ return *this;}
- };
+// dummy class for non selected custom torchbind classes
+class ClassNotSelected {
+ public:
+ ClassNotSelected& def_pickle(...) {
+ return *this;
+ }
+ ClassNotSelected& def(...) {
+ return *this;
+ }
+};
- // A SelectiveStr is like a const char*, except that it also comes
- // with a type brand that says whether or not the name is enabled or
- // not. If the string is disabled, then (at compile time) we DON'T generate
- // a registration call for it. This class is not intended to be called
- // directly; use TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros below
- // to create it.
- template <bool enabled>
- class SelectiveStr {
- public:
- constexpr explicit SelectiveStr(const char* name) : name_(name) {}
- constexpr operator const char*() { return name_; }
- private:
- const char* name_;
- };
+// A SelectiveStr is like a const char*, except that it also comes
+// with a type brand that says whether or not the name is enabled or
+// not. If the string is disabled, then (at compile time) we DON'T generate
+// a registration call for it. This class is not intended to be called
+// directly; use TORCH_SELECTIVE_NAME or TORCH_SELECTIVE_SCHEMA macros below
+// to create it.
+template <bool enabled>
+class SelectiveStr {
+ public:
+ constexpr explicit SelectiveStr(const char* name) : name_(name) {}
+ constexpr operator const char*() {
+ return name_;
+ }
-#define TORCH_SELECTIVE_CLASS(n) torch::detail::SelectiveStr<c10::impl::custom_class_allowlist_check(n)>(n)
-#define TORCH_SELECTIVE_NAME(n) torch::detail::SelectiveStr<c10::impl::op_allowlist_check(n)>(n)
-#define TORCH_SELECTIVE_SCHEMA(n) torch::detail::SelectiveStr<c10::impl::schema_allowlist_check(n)>(n)
+ private:
+ const char* name_;
+};
-}
+#define TORCH_SELECTIVE_CLASS(n) \
+ torch::detail::SelectiveStr<c10::impl::custom_class_allowlist_check(n)>(n)
+#define TORCH_SELECTIVE_NAME(n) \
+ torch::detail::SelectiveStr<c10::impl::op_allowlist_check(n)>(n)
+#define TORCH_SELECTIVE_SCHEMA(n) \
+ torch::detail::SelectiveStr<c10::impl::schema_allowlist_check(n)>(n)
+
+} // namespace detail
/// This object provides the API for defining operators and providing
/// implementations at dispatch keys. Typically, a torch::Library
@@ -461,7 +524,7 @@
/// ```
///
class TORCH_API Library final {
-public:
+ public:
/// \private
///
/// Which type of macro produced this Library
@@ -475,7 +538,12 @@
///
/// Use TORCH_LIBRARY() or TORCH_LIBRARY_IMPL() instead of using these
/// constructors directly
- Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line);
+ Library(
+ Kind kind,
+ std::string ns,
+ c10::optional<c10::DispatchKey> k,
+ const char* file,
+ uint32_t line);
Library(const Library&) = delete;
Library& operator=(const Library&) = delete;
@@ -548,7 +616,8 @@
template <typename NameOrSchema, typename Func>
Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
CppFunction f(std::forward<Func>(raw_f));
- auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
+ auto name_or_schema = detail::constructSchemaOrName(
+ std::forward<NameOrSchema>(raw_name_or_schema));
return _def(std::move(name_or_schema), std::move(f));
}
@@ -575,12 +644,12 @@
Library& impl(Name name, Func&& raw_f) & {
// TODO: need to raise an error when you impl a function that has a
// catch all def
- #if defined C10_MOBILE
+#if defined C10_MOBILE
CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
- #else
+#else
CppFunction f(std::forward<Func>(raw_f));
- #endif
- return _impl(name, std::move(f));
+#endif
+ return _impl(name, std::move(f));
}
#if defined C10_MOBILE
@@ -608,36 +677,50 @@
/// the dispatch key for the entire block in TORCH_LIBRARY_IMPL()
template <typename Name, typename Dispatch, typename Func>
Library& impl(Name name, Dispatch&& key, Func&& raw_f) & {
- return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
+ return impl(
+ name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
}
template <typename Name, typename Func>
Library& impl_UNBOXED(Name name, Func* raw_f) & {
- static_assert(c10::guts::false_t<Func>(), ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
+ static_assert(
+ c10::guts::false_t<Func>(),
+ ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
return *this;
}
- // These overloads cover cases when a SelectiveStr (see Note [Selective build])
- // has been disabled at compile time. In that case, don't generate any code
- // referencing the passed in functions at all.
- Library& def(detail::SelectiveStr<false>) & { return *this; }
+ // These overloads cover cases when a SelectiveStr (see Note [Selective
+ // build]) has been disabled at compile time. In that case, don't generate
+ // any code referencing the passed in functions at all.
+ Library& def(detail::SelectiveStr<false>) & {
+ return *this;
+ }
Library& def(detail::SelectiveStr<true> raw_schema) & {
- return def(raw_schema.operator const char *());
+ return def(raw_schema.operator const char*());
}
template <typename Func>
- Library& def(detail::SelectiveStr<false>, Func&& raw_f) & { return *this; }
+ Library& def(detail::SelectiveStr<false>, Func&& raw_f) & {
+ return *this;
+ }
template <typename Func>
Library& def(detail::SelectiveStr<true> raw_name_or_schema, Func&& raw_f) & {
- return def(raw_name_or_schema.operator const char *(), std::forward<Func>(raw_f));
+ return def(
+ raw_name_or_schema.operator const char*(), std::forward<Func>(raw_f));
}
template <typename Func>
- Library& impl(detail::SelectiveStr<false>, Func&& raw_f) & { return *this; }
+ Library& impl(detail::SelectiveStr<false>, Func&& raw_f) & {
+ return *this;
+ }
template <typename Dispatch, typename Func>
- Library& impl(detail::SelectiveStr<false>, Dispatch&& key, Func&& raw_f) & { return *this; }
+ Library& impl(detail::SelectiveStr<false>, Dispatch&& key, Func&& raw_f) & {
+ return *this;
+ }
template <typename Func>
Library& impl_UNBOXED(detail::SelectiveStr<false> name, Func* raw_f) & {
- static_assert(c10::guts::false_t<Func>(), ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
+ static_assert(
+ c10::guts::false_t<Func>(),
+ ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
return *this;
}
@@ -646,12 +729,20 @@
return impl(name.operator const char*(), std::forward<Func>(raw_f));
}
template <typename Dispatch, typename Func>
- Library& impl(detail::SelectiveStr<true> name, Dispatch&& key, Func&& raw_f) & {
- return impl(name.operator const char*(), std::forward<Dispatch>(key), std::forward<Func>(raw_f));
+ Library& impl(
+ detail::SelectiveStr<true> name,
+ Dispatch&& key,
+ Func&& raw_f) & {
+ return impl(
+ name.operator const char*(),
+ std::forward<Dispatch>(key),
+ std::forward<Func>(raw_f));
}
template <typename Func>
Library& impl_UNBOXED(detail::SelectiveStr<true> name, Func* raw_f) & {
- static_assert(c10::guts::false_t<Func>(), ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
+ static_assert(
+ c10::guts::false_t<Func>(),
+ ".impl_UNBOXED(...) was removed. Please use .impl(...) instead.");
return *this;
}
@@ -689,16 +780,17 @@
template <class CurClass>
inline torch::class_<CurClass> class_(const std::string& className);
- // These overloads enable the use of selective build on classes registered within
- // a library. The API is the same as before with 1 minor change. Instead of
- // m.class_<foo>("foo") you instead do m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
+ // These overloads enable the use of selective build on classes registered
+ // within a library. The API is the same as before with 1 minor change.
+ // Instead of m.class_<foo>("foo") you instead do
+ // m.class_<foo>(TORCH_SELECTIVE_CLASS("foo"))
template <class CurClass>
inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className);
template <class CurClass>
inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className);
-private:
+ private:
Kind kind_;
c10::optional<std::string> ns_;
c10::optional<c10::DispatchKey> dispatch_key_;
@@ -711,8 +803,12 @@
// Non-user visible actual implementations of functions. These aren't
// public because we only implement & qualifier and not && qualifier
- Library& _def(c10::FunctionSchema&& schema, c10::OperatorName* out_name = nullptr) &;
- Library& _def(c10::either<c10::OperatorName, c10::FunctionSchema>&&, CppFunction&& f) &;
+ Library& _def(
+ c10::FunctionSchema&& schema,
+ c10::OperatorName* out_name = nullptr) &;
+ Library& _def(
+ c10::either<c10::OperatorName, c10::FunctionSchema>&&,
+ CppFunction&& f) &;
Library& _impl(const char* name, CppFunction&& f) &;
Library& _fallback(CppFunction&& f) &;
};
@@ -720,12 +816,19 @@
namespace detail {
class TorchLibraryInit final {
-private:
+ private:
using InitFn = void(Library&);
Library lib_;
-public:
- TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
- : lib_(kind, ns, k, file, line) {
+
+ public:
+ TorchLibraryInit(
+ Library::Kind kind,
+ InitFn* fn,
+ const char* ns,
+ c10::optional<c10::DispatchKey> k,
+ const char* file,
+ uint32_t line)
+ : lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
@@ -734,12 +837,10 @@
} // namespace torch
-
// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/run_analyzer.sh
-
/// Macro for defining a function that will be run at static
/// initialization time to define a library of operators in the
/// namespace `ns` (must be a valid C++ identifier, no quotes).
@@ -759,14 +860,16 @@
/// The `m` argument is bound to a torch::Library that is used to
/// register operators. There may only be one TORCH_LIBRARY()
/// for any given namespace.
-#define TORCH_LIBRARY(ns, m) \
- static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
- static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
- torch::Library::DEF, \
- &TORCH_LIBRARY_init_ ## ns, \
- #ns, c10::nullopt, __FILE__, __LINE__ \
- ); \
- void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
+#define TORCH_LIBRARY(ns, m) \
+ static void TORCH_LIBRARY_init_##ns(torch::Library&); \
+ static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \
+ torch::Library::DEF, \
+ &TORCH_LIBRARY_init_##ns, \
+ #ns, \
+ c10::nullopt, \
+ __FILE__, \
+ __LINE__); \
+ void TORCH_LIBRARY_init_##ns(torch::Library& m)
/// \private
///
@@ -776,23 +879,28 @@
/// within the same namespace cannot be easily put into one macro block
/// (this is mostly the case for custom ops in fbcode that were ported from
/// the old API)
-#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID)
+#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID)
/// \private
///
-/// The above macro requires an extra unique identifier (uid) to prevent variable name collisions
-/// This can happen if TORCH_LIBRARY_FRAGMENT is called multiple times with the same namespace
-/// in the same translation unit.
-/// Note that the TORCH_LIBRARY variant doesn't run into this problem, because it enforces
-/// that it can only be called once for a given namespace.
-#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid) \
- static void C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid) (torch::Library&); \
- static const torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _, uid) ( \
- torch::Library::FRAGMENT, \
- &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid), \
- #ns, c10::nullopt, __FILE__, __LINE__ \
- ); \
- void C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid) (torch::Library& m)
+/// The above macro requires an extra unique identifier (uid) to prevent
+/// variable name collisions This can happen if TORCH_LIBRARY_FRAGMENT is called
+/// multiple times with the same namespace in the same translation unit. Note
+/// that the TORCH_LIBRARY variant doesn't run into this problem, because it
+/// enforces that it can only be called once for a given namespace.
+#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid) \
+ static void C10_CONCATENATE( \
+ TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library&); \
+ static const torch::detail::TorchLibraryInit C10_CONCATENATE( \
+ TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \
+ torch::Library::FRAGMENT, \
+ &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
+ #ns, \
+ c10::nullopt, \
+ __FILE__, \
+ __LINE__); \
+ void C10_CONCATENATE( \
+ TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(torch::Library & m)
/// Macro for defining a function that will be run at static
/// initialization time to define operator overrides for dispatch key
@@ -838,22 +946,29 @@
/// \private
///
-/// The above macro requires an extra unique identifier (uid) to prevent variable name collisions.
-/// This can happen if TORCH_LIBRARY_IMPL is called multiple times with the same namespace
-/// and dispatch key in the same translation unit.
-#define _TORCH_LIBRARY_IMPL(ns, k, m, uid) \
- static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid) (torch::Library&); \
- static const torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ## _, uid) ( \
- torch::Library::IMPL, \
- c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::k)>( \
- []() { return & C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid); }, \
- []() { return [](torch::Library&) -> void {}; } \
- ), \
- #ns, c10::make_optional(c10::DispatchKey::k), \
- __FILE__, __LINE__ \
- ); \
- void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid) (torch::Library& m)
-
+/// The above macro requires an extra unique identifier (uid) to prevent
+/// variable name collisions. This can happen if TORCH_LIBRARY_IMPL is called
+/// multiple times with the same namespace and dispatch key in the same
+/// translation unit.
+#define _TORCH_LIBRARY_IMPL(ns, k, m, uid) \
+ static void C10_CONCATENATE( \
+ TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&); \
+ static const torch::detail::TorchLibraryInit C10_CONCATENATE( \
+ TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
+ torch::Library::IMPL, \
+ c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
+ c10::DispatchKey::k)>( \
+ []() { \
+ return &C10_CONCATENATE( \
+ TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid); \
+ }, \
+ []() { return [](torch::Library&) -> void {}; }), \
+ #ns, \
+ c10::make_optional(c10::DispatchKey::k), \
+ __FILE__, \
+ __LINE__); \
+ void C10_CONCATENATE( \
+ TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)
// These are variants of the macros above which are to be used for testing (they
// don't setup the static initializer, so you can control the visibility of
@@ -863,9 +978,16 @@
// code analyzer and will be incorrectly analyzed in those situations.
/// \private
-#define MAKE_TORCH_LIBRARY(ns) torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
+#define MAKE_TORCH_LIBRARY(ns) \
+ torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
/// \private
-#define MAKE_TORCH_LIBRARY_IMPL(ns, k) torch::Library(torch::Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__)
+#define MAKE_TORCH_LIBRARY_IMPL(ns, k) \
+ torch::Library( \
+ torch::Library::IMPL, \
+ #ns, \
+ c10::make_optional(c10::DispatchKey::k), \
+ __FILE__, \
+ __LINE__)
// Make the custom class API visible, so it is available from
// torch::Library.
diff --git a/torch/script.h b/torch/script.h
index 1be4030..5851067 100644
--- a/torch/script.h
+++ b/torch/script.h
@@ -1,13 +1,13 @@
#pragma once
#include <torch/csrc/api/include/torch/types.h>
+#include <torch/csrc/autograd/InferenceMode.h>
+#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/grad_mode.h>
-#include <torch/csrc/autograd/InferenceMode.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/pickle.h>
-#include <torch/csrc/autograd/custom_function.h>
#include <torch/custom_class.h>
#include <ATen/ATen.h>
diff --git a/torch/utils/benchmark/utils/timeit_template.cpp b/torch/utils/benchmark/utils/timeit_template.cpp
index 6396d41..d739b70 100644
--- a/torch/utils/benchmark/utils/timeit_template.cpp
+++ b/torch/utils/benchmark/utils/timeit_template.cpp
@@ -11,34 +11,32 @@
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
-#include <c10/util/irange.h>
#include <torch/extension.h>
// Global setup. (e.g. #includes)
// GLOBAL_SETUP_TEMPLATE_LOCATION
double timeit(int n) {
- pybind11::gil_scoped_release no_gil;
+ pybind11::gil_scoped_release no_gil;
- // Setup
- // SETUP_TEMPLATE_LOCATION
+ // Setup
+ // SETUP_TEMPLATE_LOCATION
- {
- // Warmup
- // STMT_TEMPLATE_LOCATION
- }
+ {
+ // Warmup
+ // STMT_TEMPLATE_LOCATION
+ }
- // Main loop
- auto start_time = std::chrono::high_resolution_clock::now();
- for(const auto loop_idx : c10::irange(n)) {
- (void)loop_idx;
- // STMT_TEMPLATE_LOCATION
- }
- auto end_time = std::chrono::high_resolution_clock::now();
- return std::chrono::duration<double>(end_time - start_time).count();
+ // Main loop
+ auto start_time = std::chrono::high_resolution_clock::now();
+ for (const auto loop_idx : c10::irange(n)) {
+ (void)loop_idx;
+ // STMT_TEMPLATE_LOCATION
+ }
+ auto end_time = std::chrono::high_resolution_clock::now();
+ return std::chrono::duration<double>(end_time - start_time).count();
}
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("timeit", &timeit);
+ m.def("timeit", &timeit);
}
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
index 911a1a1..cd41f0d 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
@@ -2,35 +2,34 @@
#include <callgrind.h>
#include <pybind11/pybind11.h>
-
bool _valgrind_supported_platform() {
- #if defined(NVALGRIND)
- return false;
- #else
- return true;
- #endif
+#if defined(NVALGRIND)
+ return false;
+#else
+ return true;
+#endif
}
void _valgrind_toggle() {
- #if defined(NVALGRIND)
- TORCH_CHECK(false, "Valgrind is not supported.");
- #else
- CALLGRIND_TOGGLE_COLLECT;
- #endif
+#if defined(NVALGRIND)
+ TORCH_CHECK(false, "Valgrind is not supported.");
+#else
+ CALLGRIND_TOGGLE_COLLECT;
+#endif
}
void _valgrind_toggle_and_dump_stats() {
- #if defined(NVALGRIND)
- TORCH_CHECK(false, "Valgrind is not supported.");
- #else
- // NB: See note in Module.cpp
- CALLGRIND_TOGGLE_COLLECT;
- CALLGRIND_DUMP_STATS;
- #endif
+#if defined(NVALGRIND)
+ TORCH_CHECK(false, "Valgrind is not supported.");
+#else
+ // NB: See note in Module.cpp
+ CALLGRIND_TOGGLE_COLLECT;
+ CALLGRIND_DUMP_STATS;
+#endif
}
PYBIND11_MODULE(callgrind_bindings, m) {
- m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
- m.def("_valgrind_toggle", &_valgrind_toggle);
- m.def("_valgrind_toggle_and_dump_stats", &_valgrind_dump_stats);
+ m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
+ m.def("_valgrind_toggle", &_valgrind_toggle);
+ m.def("_valgrind_toggle_and_dump_stats", &_valgrind_dump_stats);
}
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
index 1c6bd13..bf97cf4 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
@@ -8,8 +8,8 @@
sections with user provided statements.
*/
-#include <callgrind.h>
#include <c10/util/irange.h>
+#include <callgrind.h>
#include <torch/torch.h>
#include <string>
@@ -21,45 +21,44 @@
static_assert(false);
#endif
-
int main(int argc, char* argv[]) {
- // This file should only be called inside of `Timer`, so we can adopt a
- // very simple and rigid argument parsing scheme.
- TORCH_CHECK(argc == 9);
- TORCH_CHECK(std::string(argv[1]) == "--number");
- auto number = std::stoi(argv[2]);
+ // This file should only be called inside of `Timer`, so we can adopt a
+ // very simple and rigid argument parsing scheme.
+ TORCH_CHECK(argc == 9);
+ TORCH_CHECK(std::string(argv[1]) == "--number");
+ auto number = std::stoi(argv[2]);
- TORCH_CHECK(std::string(argv[3]) == "--number_warmup");
- auto number_warmup = std::stoi(argv[4]);
+ TORCH_CHECK(std::string(argv[3]) == "--number_warmup");
+ auto number_warmup = std::stoi(argv[4]);
- TORCH_CHECK(std::string(argv[5]) == "--repeats");
- auto repeats = std::stoi(argv[6]);
+ TORCH_CHECK(std::string(argv[5]) == "--repeats");
+ auto repeats = std::stoi(argv[6]);
- TORCH_CHECK(std::string(argv[7]) == "--number_threads");
- auto number_threads = std::stoi(argv[8]);
- torch::set_num_threads(number_threads);
+ TORCH_CHECK(std::string(argv[7]) == "--number_threads");
+ auto number_threads = std::stoi(argv[8]);
+ torch::set_num_threads(number_threads);
- // Setup
- // SETUP_TEMPLATE_LOCATION
+ // Setup
+ // SETUP_TEMPLATE_LOCATION
- // Warmup
- for(const auto i : c10::irange(number_warmup)) {
- (void)i;
- // STMT_TEMPLATE_LOCATION
+ // Warmup
+ for (const auto i : c10::irange(number_warmup)) {
+ (void)i;
+ // STMT_TEMPLATE_LOCATION
+ }
+
+ // Main loop
+ for (const auto repeat : c10::irange(repeats)) {
+ (void)repeat;
+ CALLGRIND_TOGGLE_COLLECT;
+
+ for (const auto i : c10::irange(number)) {
+ (void)i;
+ // STMT_TEMPLATE_LOCATION
}
- // Main loop
- for(const auto repeat : c10::irange(repeats)) {
- (void)repeat;
- CALLGRIND_TOGGLE_COLLECT;
-
- for(const auto i : c10::irange(number)) {
- (void)i;
- // STMT_TEMPLATE_LOCATION
- }
-
- // NB: See note in Module.cpp
- CALLGRIND_TOGGLE_COLLECT;
- CALLGRIND_DUMP_STATS;
- }
+ // NB: See note in Module.cpp
+ CALLGRIND_TOGGLE_COLLECT;
+ CALLGRIND_DUMP_STATS;
+ }
}