blob: e51af27068c733dce46bf5e3ea2333a154909a21 [file] [log] [blame]
#pragma once
#include <cstddef>
#include <stdint.h>
#include "ATen/ScalarType.h"
namespace at {
/// Shared state and shape queries for the TensorAccessor templates.
///
/// The three pointers passed to the constructor are stored as given; nothing
/// in this class copies, allocates, or frees what they point to, so the
/// caller is responsible for keeping that storage valid while the accessor
/// is in use.
///
/// @tparam T element type addressed through the data pointer.
/// @tparam N length used when wrapping the sizes/strides arrays in an IntList.
template<typename T, size_t N>
class TensorAccessorBase {
public:
  /// @param data_ptr    base data pointer (only stored by this class).
  /// @param sizes_ptr   pointer to the sizes array; sizes() reads N entries.
  /// @param strides_ptr pointer to the strides array; strides() reads N entries.
  TensorAccessorBase(T * data_ptr, const int64_t * sizes_ptr, const int64_t * strides_ptr)
  : data_(data_ptr), sizes_(sizes_ptr), strides_(strides_ptr) {}

  /// @return IntList constructed from the sizes pointer and length N.
  IntList sizes() const {
    return IntList(sizes_, N);
  }

  /// @return IntList constructed from the strides pointer and length N.
  IntList strides() const {
    return IntList(strides_, N);
  }

  /// @return entry i of strides(). Indexing goes through IntList, so any
  ///         range checking is whatever IntList::operator[] provides.
  int64_t stride(int64_t i) const { return strides()[i]; }

  /// @return entry i of sizes(). Indexing goes through IntList, so any
  ///         range checking is whatever IntList::operator[] provides.
  int64_t size(int64_t i) const { return sizes()[i]; }

protected:
  T * data_;                // base data pointer, not freed by this class
  const int64_t* sizes_;    // sizes array, not freed by this class
  const int64_t* strides_;  // strides array, not freed by this class
};
/// General-case accessor (the N == 1 case is handled by a partial
/// specialization). Indexing strips the leading dimension and yields an
/// accessor with one fewer dimension.
///
/// @tparam T element type addressed through the data pointer.
/// @tparam N number of dimensions described by the sizes/strides arrays.
template<typename T, size_t N>
class TensorAccessor : public TensorAccessorBase<T,N> {
public:
  /// Forwards the three pointers unchanged to TensorAccessorBase<T,N>.
  TensorAccessor(T * data_ptr, const int64_t * sizes_ptr, const int64_t * strides_ptr)
  : TensorAccessorBase<T,N>(data_ptr, sizes_ptr, strides_ptr) {}

  /// Selects index i along the leading dimension.
  ///
  /// The data pointer is advanced by strides_[0] * i elements of T (plain
  /// pointer arithmetic), and the sizes/strides pointers are each advanced by
  /// one entry so the result describes the remaining N-1 dimensions.
  /// i is not checked against sizes_[0].
  ///
  /// const-qualified because it only reads this accessor and returns a new
  /// one by value, so it is usable through a const reference.
  TensorAccessor<T,N-1> operator[](int64_t i) const {
    return TensorAccessor<T,N-1>(
        this->data_ + this->strides_[0]*i,
        this->sizes_ + 1,
        this->strides_ + 1);
  }
};
/// One-dimensional accessor: indexing yields a reference to an element
/// rather than another accessor.
///
/// @tparam T element type addressed through the data pointer.
template<typename T>
class TensorAccessor<T,1> : public TensorAccessorBase<T,1> {
public:
  /// Forwards the three pointers unchanged to TensorAccessorBase<T,1>.
  TensorAccessor(T * data_ptr, const int64_t * sizes_ptr, const int64_t * strides_ptr)
  : TensorAccessorBase<T,1>(data_ptr, sizes_ptr, strides_ptr) {}

  /// Returns a reference to the element at offset strides_[0] * i from the
  /// data pointer. i is not checked against sizes_[0].
  ///
  /// const-qualified because the accessor itself is never modified. As with a
  /// raw pointer, constness is shallow: a const accessor still yields T&.
  /// Instantiate with a const-qualified T for read-only element access.
  T & operator[](int64_t i) const {
    return this->data_[this->strides_[0]*i];
  }
};
}