// blob: f2ff069435fe277f6e3a92263daefeae17c556ec [file] [log] [blame]
#include <torch/extension.h>
#include <ATen/CPUFloatType.h>
#include <ATen/Type.h>
#include <ATen/core/VariableHooksInterface.h>
#include <ATen/detail/ComplexHooksInterface.h>
#include "ATen/Allocator.h"
#include "ATen/CPUGenerator.h"
#include "ATen/DeviceGuard.h"
#include "ATen/NativeFunctions.h"
#include "ATen/Utils.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/core/Half.h"
#include "ATen/core/TensorImpl.h"
#include "ATen/core/UndefinedTensorImpl.h"
#include "c10/util/Optional.h"
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include "ATen/Config.h"
namespace at {
// ATen Type describing complex-float tensors on the CPU backend. Apart from
// empty(), the overrides are only declared here and are defined out of line.
struct CPUComplexFloatType : public at::CPUTypeDefault {
  // Hands the CPU tensor id to CPUTypeDefault and flags this type as neither
  // a variable type nor the undefined type.
  CPUComplexFloatType()
      : CPUTypeDefault(
            CPUTensorId(),
            /*is_variable=*/false,
            /*is_undefined=*/false) {}

  // Identity and metadata queries.
  ScalarType scalarType() const override;
  caffe2::TypeMeta typeMeta() const override;
  Backend backend() const override;
  const char* toString() const override;
  size_t elementSizeInBytes() const override;
  TypeID ID() const override;

  // Copy entry points.
  Tensor& s_copy_(Tensor& self, const Tensor& src, bool non_blocking)
      const override;
  Tensor& _s_copy_from(const Tensor& self, Tensor& dst, bool non_blocking)
      const override;

  // Creates a tensor backed by new, resizable storage that uses the CPU
  // allocator and the ComplexFloat type meta. The storage element count is
  // the product of all entries in `size`. `options` is not read.
  // NOTE(review): `size` is only used to compute that element count; nothing
  // in this function applies it as the tensor's shape -- confirm intended.
  Tensor empty(IntList size, const TensorOptions & options) const override {
    // TODO: Upstream this
    int64_t element_count = 1;
    for (const auto extent : size) {
      element_count *= extent;
    }

    Storage backing{c10::make_intrusive<StorageImpl>(
        scalarTypeToTypeMeta(ScalarType::ComplexFloat),
        element_count,
        getCPUAllocator(),
        /* resizable */ true)};

    return Tensor{c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
        std::move(backing),
        at::CPUTensorId(),
        /* is_variable */ false)};
  }
};
// Complex hooks implementation that makes the CPU ComplexFloat type known to
// an ATen Context.
struct ComplexHooks : public at::ComplexHooksInterface {
  // The constructor argument is accepted but not used.
  ComplexHooks(ComplexHooksArgs) {}

  // Registers a newly allocated CPUComplexFloatType with `context` under the
  // (Backend::CPU, ScalarType::ComplexFloat) pair.
  // NOTE(review): the object is never deleted here; presumably the context
  // takes ownership of the pointer -- confirm against Context::registerType.
  void registerComplexTypes(Context* context) const override {
    auto* complex_float_type = new CPUComplexFloatType();
    context->registerType(
        Backend::CPU, ScalarType::ComplexFloat, complex_float_type);
  }
};
// Returns the scalar type this Type handles: ComplexFloat.
ScalarType CPUComplexFloatType::scalarType() const {
  const ScalarType handled = ScalarType::ComplexFloat;
  return handled;
}
// Returns the caffe2 TypeMeta that scalarTypeToTypeMeta() yields for
// ScalarType::ComplexFloat.
caffe2::TypeMeta CPUComplexFloatType::typeMeta() const {
  const ScalarType handled = ScalarType::ComplexFloat;
  return scalarTypeToTypeMeta(handled);
}
// Returns the backend this Type belongs to: CPU.
Backend CPUComplexFloatType::backend() const {
  const Backend owning_backend = Backend::CPU;
  return owning_backend;
}
// Returns the name of this Type as a string literal.
const char* CPUComplexFloatType::toString() const {
  const char* const type_name = "CPUComplexFloatType";
  return type_name;
}
// Returns the TypeID enumerator identifying this Type: CPUComplexFloat.
TypeID CPUComplexFloatType::ID() const {
  const TypeID identifier = TypeID::CPUComplexFloat;
  return identifier;
}
// Returns the size in bytes of a single ComplexFloat element.
//
// A complex-float value consists of a real and an imaginary float, so one
// element occupies two floats. Returning sizeof(float) alone reported only
// half of the element's size.
size_t CPUComplexFloatType::elementSizeInBytes() const {
  return 2 * sizeof(float);
}
// Copying is not implemented for this Type: every call reports an error via
// AT_ERROR without looking at any of its arguments.
Tensor& CPUComplexFloatType::s_copy_(
    Tensor& /*dst*/,
    const Tensor& /*src*/,
    bool /*non_blocking*/) const {
  AT_ERROR("not yet supported");
}
// Copy-from is not implemented for this Type: every call reports an error via
// AT_ERROR without looking at any of its arguments.
Tensor& CPUComplexFloatType::_s_copy_from(
    const Tensor& /*src*/,
    Tensor& /*dst*/,
    bool /*non_blocking*/) const {
  AT_ERROR("not yet supported");
}
// Hands the ComplexHooks struct to the REGISTER_COMPLEX_HOOKS macro.
// NOTE(review): presumably this installs ComplexHooks in ATen's complex hooks
// registry so that registerComplexTypes() gets called -- confirm against
// ATen/detail/ComplexHooksInterface.h.
REGISTER_COMPLEX_HOOKS(ComplexHooks);
} // namespace at
// Python extension module entry point. The body is intentionally empty: no
// Python-visible bindings are defined in this file.
// NOTE(review): presumably the module exists only so that loading it runs the
// hook registration above -- confirm.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { }