| #pragma once |
| #include <ATen/native/TensorIterator.h> |
| #include <c10/util/SmallBuffer.h> |
| #include <c10/util/irange.h> |
| |
| namespace at { |
| |
| struct DimCounter { |
| DimCounter(IntArrayRef shape, Range range); |
| |
| void increment(const std::array<int64_t, 2>& step); |
| bool is_done() const; |
| std::array<int64_t, 2> max_2d_step() const; |
| |
| IntArrayRef shape; |
| Range range; |
| c10::SmallBuffer<int64_t, 4> values; |
| int64_t offset; |
| }; |
| |
| namespace internal { |
| |
| inline void get_data_ptrs( |
| char** ptrs, |
| ArrayRef<char*> base, |
| IntArrayRef strides, |
| IntArrayRef counter) { |
| const int64_t ntensors = base.size(); |
| const int64_t ndim = counter.size(); |
| std::copy(base.begin(), base.end(), ptrs); |
| for (const auto dim : c10::irange(ndim)) { |
| int64_t value = counter[dim]; |
| for (const auto arg : c10::irange(ntensors)) { |
| ptrs[arg] += value * strides[dim * ntensors + arg]; |
| } |
| } |
| } |
| |
| inline void serial_for_each( |
| IntArrayRef shape, |
| IntArrayRef strides, |
| char** base_ptrs, |
| size_t ntensors, |
| typename TensorIteratorBase::loop2d_t loop, |
| Range range) { |
| const auto ndim = shape.size(); |
| TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
| strides.size() == ntensors * std::max(size_t{2}, ndim)); |
| |
| if (ndim <= 1) { |
| if (range.begin == 0) { |
| loop(base_ptrs, strides.data(), range.size(), 1); |
| } else { |
| c10::SmallBuffer<char*, 4> ptrs(ntensors); |
| get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin}); |
| loop(ptrs.data(), strides.data(), range.size(), 1); |
| } |
| } else { |
| c10::SmallBuffer<char*, 4> ptrs(ntensors); |
| auto counter = DimCounter(shape, range); |
| while (!counter.is_done()) { |
| get_data_ptrs( |
| ptrs.data(), {base_ptrs, ntensors}, strides, counter.values); |
| auto step = counter.max_2d_step(); |
| loop(ptrs.data(), strides.data(), step[0], step[1]); |
| counter.increment(step); |
| } |
| } |
| } |
| |
| } // namespace internal |
| } // namespace at |