blob: 55eb3cf24999ce94ab67944ed967b621e7fa876f [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <torch/library.h>
namespace custom {
namespace native {
using at::Tensor;
using c10::ScalarType;

/// Kernel for the schema `my_ops::mul4(Tensor input) -> Tensor`.
///
/// Produces a new tensor whose elements are the elements of `in` multiplied
/// by 4. `in` is only read, never modified.
///
/// @param in  Input tensor (read only).
/// @return    A freshly allocated tensor with the same dtype, device and
///            (preserved) memory format as `in`, holding `in * 4`.
Tensor mul4_impl(const Tensor& in) {
  // clone() allocates like zeros_like() (memory format preserved by default)
  // and copies the values in a single pass, so the zero-fill that the old
  // zeros_like() + copy_() sequence performed is not needed.
  at::Tensor out = in.clone();
  // In-place multiply: the result is written into out's existing dtype.
  out.mul_(4);
  return out;
}

// Declare the operator schema under the `my_ops` namespace. The FRAGMENT
// variant lets other translation units contribute further definitions to the
// same namespace.
TORCH_LIBRARY_FRAGMENT(my_ops, m) {
  m.def("my_ops::mul4(Tensor input) -> Tensor");
}

// Bind mul4_impl as the kernel for `my_ops::mul4` under the
// CompositeExplicitAutograd dispatch key. TORCH_FN passes the function as a
// compile-time constant rather than a runtime function pointer.
TORCH_LIBRARY_IMPL(my_ops, CompositeExplicitAutograd, m) {
  m.impl("mul4", TORCH_FN(mul4_impl));
}
} // namespace native
} // namespace custom