| #pragma once | 
 |  | 
 | #include "ATen/Tensor.h" | 
 | #include <sstream> | 
 |  | 
 | namespace at { | 
 |  | 
 | inline std::tuple<Tensor> expand_inplace(const Tensor &tensor, const Tensor &to_expand) { | 
 |   if (tensor.sizes().equals(to_expand.sizes())) { | 
 |     return std::make_tuple(to_expand); | 
 |   } | 
 |  | 
 |   return std::make_tuple(to_expand.expand(tensor.sizes())); | 
 | } | 
 |  | 
 | inline std::tuple<Tensor, Tensor> expand_inplace(const Tensor &tensor, const Tensor &to_expand1, const Tensor &to_expand2) { | 
 |   if (tensor.sizes().equals(to_expand1.sizes()) && tensor.sizes().equals((to_expand2.sizes()))) { | 
 |     return std::make_tuple(to_expand1, to_expand2); | 
 |   } | 
 |  | 
 |   return std::make_tuple(to_expand1.expand(tensor.sizes()), to_expand2.expand(tensor.sizes())); | 
 | } | 
 |  | 
 | inline std::vector<int64_t> infer_size2(IntList a, IntList b) { | 
 |   auto dimsA = a.size(); | 
 |   auto dimsB = b.size(); | 
 |   ptrdiff_t ndim = dimsA > dimsB ? dimsA : dimsB; | 
 |   std::vector<int64_t> expandedSizes(ndim); | 
 |  | 
 |   for (long i = ndim - 1; i >= 0; --i) { | 
 |     long offset = ndim - 1 - i; | 
 |     long dimA = dimsA - 1 - offset; | 
 |     long dimB = dimsB - 1 - offset; | 
 |     long sizeA = (dimA >= 0) ? a[dimA] : 1; | 
 |     long sizeB = (dimB >= 0) ? b[dimB] : 1; | 
 |     if (sizeA == sizeB || sizeA == 1 || sizeB == 1) { | 
 |       expandedSizes[i] = std::max(sizeA, sizeB); | 
 |     } else { | 
 |       std::ostringstream oss; | 
 |       oss << "The size of tensor a (" << sizeA << ") must match the size of tensor b (" | 
 |           << sizeB << ") at non-singleton dimension " << i; | 
 |       throw std::runtime_error(oss.str()); | 
 |     } | 
 |   } | 
 |  | 
 |   return expandedSizes; | 
 | } | 
 |  | 
 | inline std::tuple<Tensor, Tensor> expand_outplace(const Tensor &to_expand1, const Tensor &to_expand2) { | 
 |   if (to_expand1.sizes().equals(to_expand2.sizes())) { | 
 |     return std::make_tuple(to_expand1, to_expand2); | 
 |   } | 
 |  | 
 |   auto expanded_size = infer_size2(to_expand1.sizes(), to_expand2.sizes()); | 
 |   return std::make_tuple(to_expand1.expand(expanded_size), to_expand2.expand(expanded_size)); | 
 | } | 
 |  | 
 | std::tuple<Tensor, Tensor, Tensor> expand_outplace(const Tensor &to_expand1, | 
 |                                                    const Tensor &to_expand2, | 
 |                                                    const Tensor &to_expand3) { | 
 |   if (to_expand1.sizes().equals(to_expand2.sizes()) && to_expand1.sizes().equals(to_expand3.sizes())) { | 
 |     return std::make_tuple(to_expand1, to_expand2, to_expand3); | 
 |   } | 
 |  | 
 |   auto expanded_size12 = infer_size2(to_expand1.sizes(), to_expand2.sizes()); | 
 |   auto expanded_size = infer_size2(expanded_size12, to_expand3.sizes()); | 
 |   return std::make_tuple(to_expand1.expand(expanded_size), to_expand2.expand(expanded_size), to_expand3.expand(expanded_size)); | 
 | } | 
 |  | 
 | inline std::tuple<Tensor> expand_size(const Tensor &to_expand, IntList sizes) { | 
 |   if(to_expand.sizes().equals(sizes)) { | 
 |     return std::make_tuple(to_expand); | 
 |   } | 
 |  | 
 |   return std::make_tuple(to_expand.expand(sizes)); | 
 | } | 
 |  | 
 | } |