[iOS GPU][BE][2/n] Remove unused APIs (#60281)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60281

1. REmove unused APIs from MPSImageUtils.
2. Move tensor related APIs from MetalUtils to MetalTensorUtils. Delete MetalUtils.h/mm
3. Move metal buffer related APIs to MetalContext
ghstack-source-id: 131839559

Test Plan:
1. CircleCI
2. buck test pp-mac

Reviewed By: SS-JIA

Differential Revision: D29232973

fbshipit-source-id: a4c0c848883b8ef615eeb2936c1f3d18cddcb318
diff --git a/aten/src/ATen/native/metal/MetalAten.mm b/aten/src/ATen/native/metal/MetalAten.mm
index 2677fe5..f1cbdaf 100644
--- a/aten/src/ATen/native/metal/MetalAten.mm
+++ b/aten/src/ATen/native/metal/MetalAten.mm
@@ -1,7 +1,7 @@
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
 #import <ATen/native/metal/MetalContext.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #include <ATen/metal/Context.h>
 #include <torch/script.h>
 
diff --git a/aten/src/ATen/native/metal/MetalContext.h b/aten/src/ATen/native/metal/MetalContext.h
index c459537..ca58eb9 100644
--- a/aten/src/ATen/native/metal/MetalContext.h
+++ b/aten/src/ATen/native/metal/MetalContext.h
@@ -16,5 +16,6 @@
 - (id<MTLComputePipelineState>)specializedPipelineState:(const std::string&)kernel
                                               Constants:(NSArray<NSNumber*>*)
                                                             constants;
+- (id<MTLBuffer>)emptyMTLBuffer:(int64_t) size;
 
 @end
diff --git a/aten/src/ATen/native/metal/MetalContext.mm b/aten/src/ATen/native/metal/MetalContext.mm
index 064e59e..80ee55e 100644
--- a/aten/src/ATen/native/metal/MetalContext.mm
+++ b/aten/src/ATen/native/metal/MetalContext.mm
@@ -65,7 +65,7 @@
 #else
   return false;
 #endif
-  NSError* error = [self compileProgram];
+  NSError* error = [self _compileProgram];
   if (error) {
     std::string compilationError = error.localizedDescription.UTF8String;
     std::string deviceInfo = self.description.UTF8String;
@@ -139,7 +139,22 @@
   return state;
 }
 
-- (NSError*)compileProgram {
+- (id<MTLBuffer>)emptyMTLBuffer:(int64_t) size {
+    TORCH_CHECK(_device);
+    id<MTLBuffer> buffer = [_device newBufferWithLength:size
+                      options:MTLResourceOptionCPUCacheModeWriteCombined];
+    return buffer;
+}
+
+- (NSString*)description {
+  NSString* desc =
+      [NSString stringWithFormat:@"DeviceName: %s, LanguageVersion: %lu",
+                                 _deviceInfo.name.c_str(),
+                                 (unsigned long)_deviceInfo.languageVersion];
+  return desc;
+}
+
+- (NSError*)_compileProgram {
   __block NSError* compilationError = nil;
   static dispatch_once_t onceToken;
   dispatch_once(&onceToken, ^{
@@ -156,12 +171,6 @@
   return compilationError;
 }
 
-- (NSString*)description {
-  NSString* desc =
-      [NSString stringWithFormat:@"DeviceName: %s, LanguageVersion: %lu",
-                                 _deviceInfo.name.c_str(),
-                                 (unsigned long)_deviceInfo.languageVersion];
-  return desc;
-}
+
 
 @end
diff --git a/aten/src/ATen/native/metal/MetalTensorImplStorage.mm b/aten/src/ATen/native/metal/MetalTensorImplStorage.mm
index 91b336c..cd73ba4 100644
--- a/aten/src/ATen/native/metal/MetalTensorImplStorage.mm
+++ b/aten/src/ATen/native/metal/MetalTensorImplStorage.mm
@@ -1,6 +1,6 @@
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
 
 #include <ATen/Utils.h>
@@ -49,7 +49,7 @@
 
 MetalTensorImplStorage::MetalTensorImplStorage(
     const std::vector<int64_t>& sizes)
-    : MetalTensorImplStorage(sizes, compute_strides(sizes)) {}
+    : MetalTensorImplStorage(sizes, computeStrides(sizes)) {}
 
 MetalTensorImplStorage::MetalTensorImplStorage(
     const std::vector<int64_t>& sizes,
diff --git a/aten/src/ATen/native/metal/MetalTensorUtils.h b/aten/src/ATen/native/metal/MetalTensorUtils.h
index aaa1434..318da09 100644
--- a/aten/src/ATen/native/metal/MetalTensorUtils.h
+++ b/aten/src/ATen/native/metal/MetalTensorUtils.h
@@ -1,4 +1,14 @@
 #include <ATen/Tensor.h>
+#include <ATen/native/metal/MetalContext.h>
+#include <ATen/native/metal/MetalCommandBuffer.h>
+#include <ATen/native/metal/MetalTensorImpl.h>
+#include <ATen/native/metal/MetalTensorImplStorage.h>
+
+#if (defined(__ARM_NEON__) || defined(__ARM_NEON))
+typedef float16_t fp16_t;
+#else
+typedef uint16_t fp16_t;
+#endif
 
 namespace at {
 namespace native {
@@ -9,6 +19,57 @@
 uint32_t heightSize(const Tensor& tensor);
 uint32_t widthSize(const Tensor& tensor);
 
+// When copying the result back to a CPU tensor, the memory format becomes NCHW.
+// Thus,we compute the strides based on contiguous memory format.
+static inline std::vector<int64_t> computeStrides(
+    const std::vector<int64_t>& sizes) {
+  const auto dim = sizes.size();
+  std::vector<int64_t> strides(dim, 0);
+  if (dim > 0) {
+    const auto last_idx = dim - 1;
+    strides[last_idx] = 1;
+    for (int64_t i = last_idx - 1; i >= 0; --i) {
+      strides[i] = strides[i + 1] * std::max<int64_t>(sizes[i + 1], 1);
+    }
+  }
+  return strides;
+}
+
+static inline MetalTensorImplStorage& getTensorImplStorage(
+    const at::Tensor& tensor) {
+  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
+  TORCH_CHECK(tensor.is_metal());
+  MetalTensorImpl* impl =
+      static_cast<MetalTensorImpl*>(tensor.unsafeGetTensorImpl());
+  return impl->unsafe_opaque_handle();
+}
+
+static inline at::Tensor makeTensor(
+    MetalTensorImplStorage&& mt,
+    const TensorOptions& options) {
+  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
+  auto sizes = mt.sizes(); // sizes is stored in TensorImpl
+  auto strides = mt.strides(); // strides is stored in MetalTensorImpl
+  return detail::make_tensor<MetalTensorImpl>(
+      DispatchKeySet(DispatchKey::Metal),
+      options.dtype(),
+      at::Device(at::kMetal),
+      std::move(mt),
+      std::vector<int64_t>(sizes.begin(), sizes.end()),
+      std::vector<int64_t>(strides.begin(), strides.end()));
+}
+
+static inline MetalCommandBuffer* getCommandBuffer(
+    const Tensor& tensor) {
+  TORCH_CHECK(tensor.is_metal());
+  auto implStorage = getTensorImplStorage(tensor);
+  MetalCommandBuffer* cmdBuffer = implStorage.texture()->commandBuffer();
+  if (!cmdBuffer || !cmdBuffer.valid) {
+    cmdBuffer = [MetalCommandBuffer currentBuffer];
+  }
+  return cmdBuffer;
+}
+
 } // namespace metal
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/metal/MetalUtils.h b/aten/src/ATen/native/metal/MetalUtils.h
deleted file mode 100644
index a94fce6..0000000
--- a/aten/src/ATen/native/metal/MetalUtils.h
+++ /dev/null
@@ -1,97 +0,0 @@
-#include <ATen/Tensor.h>
-#include <ATen/native/metal/MetalContext.h>
-#include <ATen/native/metal/MetalCommandBuffer.h>
-#include <ATen/native/metal/MetalTensorImpl.h>
-#include <ATen/native/metal/MetalTensorImplStorage.h>
-#include <vector>
-
-#if (defined(__ARM_NEON__) || defined(__ARM_NEON))
-typedef float16_t fp16_t;
-#else
-typedef uint16_t fp16_t;
-#endif
-
-namespace at {
-namespace native {
-namespace metal {
-
-std::vector<fp16_t> Fp32ToFp16(const std::vector<float>& src);
-std::vector<float> Fp16ToFp32(const std::vector<fp16_t>& src);
-
-std::vector<float> NCHWToNC4(
-    const float* src,
-    const std::vector<int64_t>& sizes);
-std::vector<float> NC4ToNCHW(
-    const float* src,
-    const std::vector<int64_t>& sizes);
-
-// When copying the result back to a CPU tensor, the memory format becomes NCHW.
-// Thus,we compute the strides based on contiguous memory format.
-static inline std::vector<int64_t> compute_strides(
-    const std::vector<int64_t>& sizes) {
-  const auto dim = sizes.size();
-  std::vector<int64_t> strides(dim, 0);
-  if (dim > 0) {
-    const auto last_idx = dim - 1;
-    strides[last_idx] = 1;
-    for (int i = last_idx - 1; i >= 0; --i) {
-      strides[i] = strides[i + 1] * std::max<int64_t>(sizes[i + 1], 1);
-    }
-  }
-  return strides;
-}
-
-static inline MetalTensorImplStorage& getTensorImplStorage(
-    const at::Tensor& tensor) {
-  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
-  TORCH_CHECK(tensor.is_metal());
-  MetalTensorImpl* impl =
-      static_cast<MetalTensorImpl*>(tensor.unsafeGetTensorImpl());
-  return impl->unsafe_opaque_handle();
-}
-
-static inline at::Tensor makeTensor(
-    MetalTensorImplStorage&& mt,
-    const TensorOptions& options) {
-  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
-  auto sizes = mt.sizes(); // sizes is stored in TensorImpl
-  auto strides = mt.strides(); // strides is stored in MetalTensorImpl
-  return detail::make_tensor<MetalTensorImpl>(
-      DispatchKeySet(DispatchKey::Metal),
-      options.dtype(),
-      at::Device(at::kMetal),
-      std::move(mt),
-      std::vector<int64_t>(sizes.begin(), sizes.end()),
-      std::vector<int64_t>(strides.begin(), strides.end()));
-}
-
-static inline MetalCommandBuffer* getCommandBufferFromTensor(
-    const Tensor& tensor) {
-  TORCH_CHECK(tensor.is_metal());
-  auto implStorage = getTensorImplStorage(tensor);
-  MetalCommandBuffer* cmdBuffer = implStorage.texture()->commandBuffer();
-  if (!cmdBuffer || !cmdBuffer.valid) {
-    cmdBuffer = [MetalCommandBuffer currentBuffer];
-  }
-  return cmdBuffer;
-}
-
-template<typename T>
-id<MTLBuffer>makeMTLBuffer(const std::vector<T>& src) {
-    id<MTLBuffer> buffer = [[MetalContext sharedInstance].device
-          newBufferWithLength:src.size() * sizeof(T)
-                      options:MTLResourceOptionCPUCacheModeWriteCombined];
-    memcpy(buffer.contents, src.data(), src.size() * sizeof(T));
-    return buffer;
-}
-
-static inline id<MTLBuffer>makeMTLBuffer(int64_t bytes) {
-    id<MTLBuffer> buffer = [[MetalContext sharedInstance].device
-          newBufferWithLength:bytes
-                      options:MTLResourceOptionCPUCacheModeWriteCombined];
-    return buffer;
-}
-
-} // namespace metal
-} // namespace native
-} // namespace at
diff --git a/aten/src/ATen/native/metal/MetalUtils.mm b/aten/src/ATen/native/metal/MetalUtils.mm
deleted file mode 100644
index a082c15..0000000
--- a/aten/src/ATen/native/metal/MetalUtils.mm
+++ /dev/null
@@ -1,100 +0,0 @@
-#import <ATen/native/metal/MetalUtils.h>
-#import <ATen/native/metal/MetalContext.h>
-#import <Accelerate/Accelerate.h>
-
-namespace at {
-namespace native {
-namespace metal {
-
-std::vector<fp16_t> Fp32ToFp16(const std::vector<float>& src) {
-    unsigned long count = src.size();
-    std::vector<fp16_t> output(count, 0);
-    vImage_Buffer float32{(void*)src.data(), 1, count, count * sizeof(float)};
-    vImage_Buffer float16{(void*)output.data(), 1, count, count * sizeof(fp16_t)};
-    if (vImageConvert_PlanarFtoPlanar16F(&float32, &float16, 0) !=
-        kvImageNoError) {
-      TORCH_CHECK(false);
-    }
-  return output;
-}
-
-std::vector<float> Fp16ToFp32(const std::vector<fp16_t>& src) {
-  unsigned long count = src.size();
-  std::vector<float> output(count, 0);
-  vImage_Buffer float16{(void*)src.data(), 1, count, count * sizeof(fp16_t)};
-  vImage_Buffer float32{(void*)output.data(), 1, count, count * sizeof(float)};
-  if (vImageConvert_Planar16FtoPlanarF(&float16, &float32, 0) !=
-      kvImageNoError) {
-    TORCH_CHECK(false);
-  }
-  return output;
-}
-
-std::vector<float> NCHWToNC4(
-    const float* src,
-    const std::vector<int64_t>& sizes) {
-  int64_t N = sizes[0];
-  int64_t C = sizes[1];
-  int64_t H = sizes[2];
-  int64_t W = sizes[3];
-  int64_t src_image_count = C * H * W;
-  int64_t src_count = N * src_image_count;
-  int64_t slices = (C + 3) / 4;
-  int64_t numComponents = C < 3 ? C : 4;
-  int64_t dst_image_count = slices * numComponents * W * H;
-  int64_t dst_count = N * dst_image_count;
-  std::vector<float> output(dst_count, 0.0f);
-  for (int n = 0; n < N; ++n) {
-    int64_t src_image = n * src_image_count;
-    int64_t dst_image = n * dst_image_count;
-    for (int i = 0; i < slices; ++i) {
-      int64_t slice = i * W * H * numComponents;
-      for (int j = 0; j < W * H; ++j) {
-        for (int k = 0; k < numComponents; ++k) {
-          int ii = src_image + slice + k * W * H + j;
-          int oi = dst_image + slice + j * numComponents + k;
-          if (k < C && ii < src_count) {
-            output[oi] = src[ii];
-          }
-        }
-      }
-    }
-  }
-  return output;
-}
-
-std::vector<float> NC4ToNCHW(
-    const float* src,
-    const std::vector<int64_t>& sizes) {
-  int64_t N = sizes[0];
-  int64_t C = sizes[1];
-  int64_t H = sizes[2];
-  int64_t W = sizes[3];
-  int64_t slices = (C + 3) / 4;
-  int64_t numComponents = C < 3 ? C : 4;
-  int64_t src_image_count = slices * numComponents * W * H;
-  int64_t dst_image_count = C * H * W;
-  int64_t dst_count = N * dst_image_count;
-  std::vector<float> output(dst_count, 0.0f);
-  for (int n = 0; n < N; ++n) {
-    int64_t src_image = n * src_image_count;
-    int64_t dst_image = n * dst_image_count;
-    for (int i = 0; i < slices; ++i) {
-      int64_t slice = i * W * H * numComponents;
-      for (int j = 0; j < numComponents; ++j) {
-        for (int k = 0; k < W * H; ++k) {
-          int ii = src_image + slice + k * numComponents + j;
-          int oi = dst_image + slice + j * W * H + k;
-          if (j < C && oi < dst_count) {
-            output[oi] = src[ii];
-          }
-        }
-      }
-    }
-  }
-  return output;
-}
-
-}
-}
-}
diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm
index 2d380dd..93218bd 100644
--- a/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm
+++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm
@@ -1,4 +1,4 @@
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
 #import <ATen/native/metal/MetalContext.h>
diff --git a/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.h b/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.h
index 4ef2247..53065e4 100644
--- a/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.h
+++ b/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.h
@@ -1,7 +1,7 @@
 #import <ATen/Tensor.h>
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
@@ -11,13 +11,8 @@
 
 MPSImage* createStaticImage(IntArrayRef sizes);
 MPSImage* createStaticImage(
-    const fp16_t* src,
-    const IntArrayRef sizes);
-MPSImage* createStaticImage(
     const float* src,
     const IntArrayRef sizes);
-MPSImage* createStaticImage(const at::Tensor& tensor);
-MPSImage* createStaticImage(MPSImage* image);
 MPSImage* createStaticImage(
     MPSTemporaryImage* image,
     MetalCommandBuffer* buffer,
@@ -37,9 +32,6 @@
 void copyToHost(float* dst, MPSImage* image);
 void copyToMetalBuffer(MetalCommandBuffer* buffer, id<MTLBuffer> dst, MPSImage* image);
 
-std::vector<fp16_t> staticImageToFp16Array(MPSImage* image);
-at::Tensor staticImageToTensor(MPSImage* image);
-
 static inline MPSImage* imageFromTensor(const Tensor& tensor) {
   TORCH_CHECK(tensor.is_metal());
   using MetalTensorImplStorage = at::native::metal::MetalTensorImplStorage;
@@ -63,7 +55,7 @@
   std::vector<int64_t> imageSize(4, 1);
   int64_t index = 3;
   int64_t batch = 1;
-  for (int i = sizes.size() - 1; i >= 0; i--) {
+  for (int64_t i = sizes.size() - 1; i >= 0; i--) {
     if (index != 0) {
         imageSize[index] = sizes[i];
       index--;
diff --git a/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm b/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm
index 8d2f171..817672a 100644
--- a/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm
+++ b/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm
@@ -1,4 +1,4 @@
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -24,39 +24,6 @@
                           imageDescriptor:desc];
 }
 
-MPSImage* createStaticImage(const fp16_t* src, IntArrayRef sizes) {
-  int64_t N = sizes[0];
-  int64_t C = sizes[1];
-  int64_t H = sizes[2];
-  int64_t W = sizes[3];
-  MPSImageDescriptor* desc = [MPSImageDescriptor
-      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
-                                 width:W
-                                height:H
-                       featureChannels:C
-                        numberOfImages:N
-                                 usage:MTLTextureUsageShaderRead |
-                                 MTLTextureUsageShaderWrite];
-  MPSImage* image =
-      [[MPSImage alloc] initWithDevice:[MetalContext sharedInstance].device
-                       imageDescriptor:desc];
-
-  int64_t slices = (C + 3) / 4 * N;
-  int64_t numComponents = image.featureChannels < 3 ? image.featureChannels : 4;
-  int64_t bytesPerRow = W * numComponents * sizeof(fp16_t);
-  uint8_t* ptr = (uint8_t*)src;
-  for (int i = 0; i < slices; ++i) {
-    [image.texture replaceRegion:MTLRegionMake2D(0, 0, W, H)
-                     mipmapLevel:0
-                           slice:i
-                       withBytes:ptr
-                     bytesPerRow:bytesPerRow
-                   bytesPerImage:0];
-    ptr += H * bytesPerRow;
-  }
-  return image;
-}
-
 MPSImage* createStaticImage(const float* src, IntArrayRef sizes) {
   int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
   id<MTLBuffer> buff = [[MetalContext sharedInstance].device
@@ -88,36 +55,6 @@
   return output;
 }
 
-MPSImage* createStaticImage(const at::Tensor& tensor) {
-  TORCH_CHECK(tensor.device().is_cpu());
-  TORCH_CHECK(tensor.dim() == 4);
-  auto contiguousTensor = tensor.contiguous();
-  float* src = tensor.data_ptr<float>();
-  std::vector<int64_t> sizes = tensor.sizes().vec();
-  auto c4 = NCHWToNC4(src, sizes);
-  auto c4fp16 = Fp32ToFp16(c4);
-  return createStaticImage(c4fp16.data(), sizes);
-}
-
-MPSImage* createStaticImage(MPSImage* image) {
-  MPSImage* Y = createStaticImage([image sizes]);
-  MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer];
-  id<MTLComputeCommandEncoder> encoder = [cb.buffer computeCommandEncoder];
-  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
-      pipelineState:mpscnn::kernelFor(image, "copy", "copy_nonarray")];
-  [encoder setComputePipelineState:state];
-  [encoder setTexture:[image texture] atIndex:0];
-  [encoder setTexture:[Y texture] atIndex:1];
-
-  const auto& launchParams =
-      mpscnn::spatialPointwiseKernelLaunchParams(state, image);
-  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
-          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
-  [encoder endEncoding];
-  [cb commit];
-  return Y;
-}
-
 MPSImage* createStaticImage(
     MPSTemporaryImage* image,
     MetalCommandBuffer* buffer,
@@ -276,44 +213,6 @@
   [encoder endEncoding];
 }
 
-std::vector<fp16_t> staticImageToFp16Array(MPSImage* image) {
-  if (image.pixelFormat == MTLPixelFormatR16Float ||
-      image.pixelFormat == MTLPixelFormatRG16Float ||
-      image.pixelFormat == MTLPixelFormatRGBA16Float) {
-    int64_t slices = (image.featureChannels + 3) / 4;
-    int64_t C = image.featureChannels < 3 ? image.featureChannels : slices * 4;
-    int64_t numComponents =
-        image.featureChannels < 3 ? image.featureChannels : 4;
-    int64_t count = image.width * image.height * image.numberOfImages * C;
-    std::vector<fp16_t> output(count, 0);
-    int64_t bytesPerRow = image.width * numComponents * sizeof(fp16_t);
-    uint8_t* buffer = (uint8_t*)output.data();
-    for (int i = 0; i < slices * image.numberOfImages; ++i) {
-      [image.texture getBytes:buffer
-                  bytesPerRow:bytesPerRow
-                bytesPerImage:0
-                   fromRegion:MTLRegionMake2D(0, 0, image.width, image.height)
-                  mipmapLevel:0
-                        slice:i];
-      buffer += image.height * bytesPerRow;
-    }
-    return output;
-  }
-  TORCH_CHECK(
-      false, "Copy to float buffer failed: The pixel format didn't match");
-}
-
-at::Tensor staticImageToTensor(MPSImage* image) {
-  auto outputSize = [image sizes];
-  std::vector<fp16_t> fp16Array = staticImageToFp16Array(image);
-  auto fp32 = metal::Fp16ToFp32(fp16Array);
-  std::vector<float> fp32_nchw = metal::NC4ToNCHW(fp32.data(), outputSize);
-  auto tensor = at::empty(outputSize);
-  int64_t size_bytes = c10::multiply_integers(outputSize) * sizeof(float);
-  memcpy(tensor.data_ptr(), fp32_nchw.data(), size_bytes);
-  return tensor;
-}
-
 }
 }
 }
diff --git a/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm b/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm
index 10f3967..1d94bdd 100644
--- a/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm
+++ b/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm
@@ -1,5 +1,5 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
index 95e8f16..57c1c8a 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
@@ -2,7 +2,6 @@
 #define MPSCNNTests_h
 
 bool test_synchronization();
-bool test_nchw_to_nc4_cpu();
 bool test_copy_nchw_to_metal();
 bool test_conv2d();
 bool test_depthwiseConv();
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
index f21fb13..5df31eb 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
@@ -1,5 +1,5 @@
 #import <ATen/ATen.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
 #import <ATen/native/metal/mpscnn/MPSImageUtils.h>
 #import <ATen/native/metal/mpscnn/tests/MPSCNNTests.h>
@@ -86,13 +86,7 @@
     }
     std::cout << str << std::endl;
   };
-  if (tensor.is_metal()) {
-    MPSImage* image = at::native::metal::imageFromTensor(tensor);
-    auto t = at::native::metal::staticImageToTensor(image);
-    print(t);
-  } else {
-    print(tensor);
-  }
+  print(tensor);
 }
 
 }
@@ -111,29 +105,6 @@
   });
 }
 
-bool test_nchw_to_nc4_cpu() {
-  bool result = true;
-  for (int i = 0; i < ITER_COUNT; ++i) {
-    int64_t N = rand(1, 24);
-    int64_t C = rand(1, 48);
-    int64_t H = rand(1, 320);
-    int64_t W = rand(1, 320);
-    __block std::vector<int64_t> size{N, C, H, W};
-    bool b = TEST(size, __PRETTY_FUNCTION__, ^bool {
-      auto t = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
-      const auto len = c10::multiply_integers(std::begin(size), std::end(size));
-      auto buf =
-          std::vector<float>{t.data_ptr<float>(), t.data_ptr<float>() + len};
-      auto c4 = NCHWToNC4((float*)t.data_ptr<float>(), t.sizes().vec());
-      auto n4 = NC4ToNCHW((float*)c4.data(), t.sizes().vec());
-      return n4 == buf;
-    });
-    if (!b) {
-      result = false;
-    }
-  }
-  return result;
-}
 
 bool test_copy_nchw_to_metal() {
   __block std::vector<int64_t> size{1, 3, 224, 224};
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm
index eb36e4d..5c9ecfb 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm
@@ -6,13 +6,12 @@
 #import <ATen/native/metal/mpscnn/tests/MetalOpTestRunner.h>
 
 @implementation MetalOpTestRunner {
-  NSMutableDictionary *_tests;
+  NSMutableDictionary* _tests;
 }
 
-+ (instancetype)sharedInstance
-{
++ (instancetype)sharedInstance {
   static dispatch_once_t onceToken;
-  static MetalOpTestRunner *instance = nil;
+  static MetalOpTestRunner* instance = nil;
   dispatch_once(&onceToken, ^{
     instance = [MetalOpTestRunner new];
   });
@@ -29,9 +28,11 @@
 
 - (void)registerTests {
   _tests = [NSMutableDictionary dictionary];
-#define REG_TEST(arg1, arg2) _tests[@arg1] = ^BOOL(void){return arg2();}
+#define REG_TEST(arg1, arg2)    \
+  _tests[@arg1] = ^BOOL(void) { \
+    return arg2();              \
+  }
   REG_TEST("test_synchronization", test_synchronization);
-  REG_TEST("test_nchw_to_nc4_cpu", test_nchw_to_nc4_cpu);
   REG_TEST("test_copy_nchw_to_metal", test_copy_nchw_to_metal);
   REG_TEST("test_conv2d", test_conv2d);
   REG_TEST("test_depthwiseConv", test_depthwiseConv);
@@ -81,7 +82,7 @@
   REG_TEST("test_reflection_pad2d", test_reflection_pad2d);
 }
 
-- (NSDictionary *) tests {
+- (NSDictionary*)tests {
   return _tests;
 }
 
diff --git a/aten/src/ATen/native/metal/ops/MetalAddmm.mm b/aten/src/ATen/native/metal/ops/MetalAddmm.mm
index c023d91..8086b17 100644
--- a/aten/src/ATen/native/metal/ops/MetalAddmm.mm
+++ b/aten/src/ATen/native/metal/ops/MetalAddmm.mm
@@ -3,7 +3,7 @@
 #import <ATen/native/metal/MetalPrepackOpContext.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h>
@@ -45,7 +45,7 @@
   auto packedWeights = weight_.contiguous(c10::MemoryFormat::ChannelsLast);
   MetalTensorImplStorage mt{{params.N, params.OC}};
   SmallVector<int64_t, 4> textureSize = {params.N, params.OC, 1, 1};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input_);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
   mt.texture()->allocateTemporaryStorage(textureSize, commandBuffer);
   MPSImage* Y = mt.texture()->image();
   float* w = packedWeights.data_ptr<float>();
@@ -101,7 +101,7 @@
   }
   MetalTensorImplStorage mt{{params.N, params.OC}};
   SmallVector<int64_t, 4> textureSize = {params.N, params.OC, 1, 1};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input_);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
   mt.texture()->allocateTemporaryStorage(textureSize, commandBuffer);
   MPSImage* Y1 = mt.texture()->image();
   // HACK alert:
diff --git a/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm b/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm
index 294913d..97294a1 100644
--- a/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm
+++ b/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm
@@ -76,8 +76,8 @@
     return makeTensor({outputSize.vec()}, input1.options());
   }
   MetalTensorImplStorage mt{outputSize.vec()};
-  MetalCommandBuffer* cb1 = getCommandBufferFromTensor(input1);
-  MetalCommandBuffer* cb2 = getCommandBufferFromTensor(input2);
+  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
+  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
   TORCH_CHECK(
       [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
   mt.texture()->allocateTemporaryStorage(outputSize, cb1);
@@ -117,8 +117,8 @@
   if(c10::multiply_integers(outputSize) == 0){
       return input1;
   }
-  MetalCommandBuffer* cb1 = getCommandBufferFromTensor(input1);
-  MetalCommandBuffer* cb2 = getCommandBufferFromTensor(input2);
+  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
+  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
   TORCH_CHECK(
       [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
   MPSImage* Y = createTemporaryImage(cb1, outputSize.vec());
@@ -159,8 +159,8 @@
       return makeTensor({outputSize.vec()}, input1.options());
   }
   MetalTensorImplStorage mt{outputSize.vec()};
-  MetalCommandBuffer* cb1 = getCommandBufferFromTensor(input1);
-  MetalCommandBuffer* cb2 = getCommandBufferFromTensor(input2);
+  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
+  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
   TORCH_CHECK(
       [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
   mt.texture()->allocateTemporaryStorage(outputSize, cb1);
@@ -192,8 +192,8 @@
   if(c10::multiply_integers(outputSize) == 0){
     return input1;
   }
-  MetalCommandBuffer* cb1 = getCommandBufferFromTensor(input1);
-  MetalCommandBuffer* cb2 = getCommandBufferFromTensor(input2);
+  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
+  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
   TORCH_CHECK(
       [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
   MPSImage* Y = createTemporaryImage(cb1, outputSize.vec());
diff --git a/aten/src/ATen/native/metal/ops/MetalChunk.mm b/aten/src/ATen/native/metal/ops/MetalChunk.mm
index 89d8d36..3da3e68 100644
--- a/aten/src/ATen/native/metal/ops/MetalChunk.mm
+++ b/aten/src/ATen/native/metal/ops/MetalChunk.mm
@@ -2,7 +2,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -28,7 +28,7 @@
   std::vector<Tensor> splits(num_splits);
   int64_t last_split_size = split_size - (split_size * num_splits - dim_size);
   MPSImage* X = imageFromTensor(input);
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   auto outputSize1 = {input.size(0), split_size, input.size(2), input.size(3)};
   auto outputSize2 = {input.size(0), last_split_size, input.size(2), input.size(3)};
   MetalTensorImplStorage mt1(outputSize1);
diff --git a/aten/src/ATen/native/metal/ops/MetalClamp.mm b/aten/src/ATen/native/metal/ops/MetalClamp.mm
index bf96813..23ed28d 100644
--- a/aten/src/ATen/native/metal/ops/MetalClamp.mm
+++ b/aten/src/ATen/native/metal/ops/MetalClamp.mm
@@ -2,7 +2,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
 #import <ATen/native/metal/mpscnn/MPSImageUtils.h>
@@ -15,7 +15,7 @@
 Tensor& hardtanh_(Tensor& input, const Scalar& min_val, const Scalar& max_val) {
   TORCH_CHECK(input.is_metal());
   MPSImage* X = imageFromTensor(input);
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   MPSImage* Y = createTemporaryImage(commandBuffer, input.sizes().vec());
   float min = min_val.toFloat();
   float max = max_val.toFloat();
diff --git a/aten/src/ATen/native/metal/ops/MetalConcat.mm b/aten/src/ATen/native/metal/ops/MetalConcat.mm
index 7e143da..2b34bc4 100644
--- a/aten/src/ATen/native/metal/ops/MetalConcat.mm
+++ b/aten/src/ATen/native/metal/ops/MetalConcat.mm
@@ -2,7 +2,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -18,13 +18,13 @@
 
 Tensor cat_batch(const TensorList tensors, MetalTensorImplStorage& mt) {
   at::Tensor tensor = tensors[0];
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(tensor);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor);
   MPSImage* Y = mt.texture()->image();
   ushort cat_dim4_pointer = 0;
   for (int i = 0; i < tensors.size(); ++i) {
     const auto& t = tensors[i];
     MPSImage* X = imageFromTensor(t);
-    MetalCommandBuffer* Xcb = getCommandBufferFromTensor(t);
+    MetalCommandBuffer* Xcb = getCommandBuffer(t);
     TORCH_CHECK(
         [commandBuffer isEqual:Xcb],
         @"inputs have different Metal command buffers");
@@ -58,13 +58,13 @@
 
 Tensor cat_feature(const TensorList tensors, MetalTensorImplStorage& mt) {
   at::Tensor tensor = tensors[0];
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(tensor);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor);
   MPSImage* Y = mt.texture()->image();
   ushort channel_offset = 0;
   for (int i = 0; i < tensors.size(); ++i) {
     const auto& t = tensors[i];
     MPSImage* X = imageFromTensor(t);
-    MetalCommandBuffer* Xcb = getCommandBufferFromTensor(t);
+    MetalCommandBuffer* Xcb = getCommandBuffer(t);
     TORCH_CHECK(
         [commandBuffer isEqual:Xcb],
         @"inputs have different Metal command buffers");
@@ -124,7 +124,7 @@
       "Metal cat is implemented only for batch dimension");
   int64_t cat_dim_size = 0;
   at::Tensor tensor = tensors[0];
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(tensor);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor);
   for (int i = 0; i < tensors.size(); ++i) {
     const auto& t = tensors[i];
     TORCH_CHECK(t.dim() == 4, "Metal cat expects 4 dimensional inputs");
diff --git a/aten/src/ATen/native/metal/ops/MetalConvolution.mm b/aten/src/ATen/native/metal/ops/MetalConvolution.mm
index 98fc87e..c726382 100644
--- a/aten/src/ATen/native/metal/ops/MetalConvolution.mm
+++ b/aten/src/ATen/native/metal/ops/MetalConvolution.mm
@@ -1,6 +1,6 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
 #import <ATen/native/metal/mpscnn/MPSCNNConvOp.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -42,7 +42,7 @@
                                      bias:b
                              neuronFilter:NeuronType::None];
   MetalTensorImplStorage mt{outputSize};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
   MPSImage* Y = mt.texture()->image();
   [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y];
@@ -79,7 +79,7 @@
     };
   }
   MetalTensorImplStorage mt{outputSize};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
   MPSImage* Y1 = mt.texture()->image();
   [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y1];
diff --git a/aten/src/ATen/native/metal/ops/MetalCopy.mm b/aten/src/ATen/native/metal/ops/MetalCopy.mm
index a51eab9..b6c783b 100644
--- a/aten/src/ATen/native/metal/ops/MetalCopy.mm
+++ b/aten/src/ATen/native/metal/ops/MetalCopy.mm
@@ -1,7 +1,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -19,7 +19,7 @@
   if (X && !X.isTemporaryImage) {
     return input;
   }
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   auto&& sizes = [X sizes];
   MetalTensorImplStorage mt{sizes};
   mt.texture()->setCommandBuffer(commandBuffer);
diff --git a/aten/src/ATen/native/metal/ops/MetalHardswish.mm b/aten/src/ATen/native/metal/ops/MetalHardswish.mm
index 66bf362..f446b4c 100644
--- a/aten/src/ATen/native/metal/ops/MetalHardswish.mm
+++ b/aten/src/ATen/native/metal/ops/MetalHardswish.mm
@@ -2,7 +2,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -17,7 +17,7 @@
 
 Tensor& hardswish_(Tensor& input) {
   MPSImage* X = imageFromTensor(input);
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   IntArrayRef outputSize = input.sizes();
   std::vector<int64_t> imageSize = computeImageSize(outputSize);
   MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);
diff --git a/aten/src/ATen/native/metal/ops/MetalNeurons.mm b/aten/src/ATen/native/metal/ops/MetalNeurons.mm
index b095d33..5ecbf2b 100644
--- a/aten/src/ATen/native/metal/ops/MetalNeurons.mm
+++ b/aten/src/ATen/native/metal/ops/MetalNeurons.mm
@@ -2,7 +2,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -23,7 +23,7 @@
   }
   IntArrayRef textureSize = outputSize;
   MetalTensorImplStorage mt{outputSize.vec()};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   mt.texture()->allocateTemporaryStorage(textureSize, commandBuffer);
   MPSImage* Y = mt.texture()->image();
   [neuron encodeToCommandBuffer:commandBuffer.buffer
@@ -40,7 +40,7 @@
     return input;
   }
   IntArrayRef textureSize = outputSize;
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   MPSImage* Y = createTemporaryImage(commandBuffer, textureSize);
   [neuron encodeToCommandBuffer:commandBuffer.buffer
                     sourceImage:X
diff --git a/aten/src/ATen/native/metal/ops/MetalPadding.mm b/aten/src/ATen/native/metal/ops/MetalPadding.mm
index 2610790..9a37f7e 100644
--- a/aten/src/ATen/native/metal/ops/MetalPadding.mm
+++ b/aten/src/ATen/native/metal/ops/MetalPadding.mm
@@ -1,7 +1,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -48,7 +48,7 @@
   }
 
   MPSImage* X = imageFromTensor(input);
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   MetalTensorImplStorage mt{output_size};
   mt.texture()->allocateTemporaryStorage(output_size, commandBuffer);
   MPSImage* Y = mt.texture()->image();
diff --git a/aten/src/ATen/native/metal/ops/MetalPooling.mm b/aten/src/ATen/native/metal/ops/MetalPooling.mm
index 945fc84..db8f8fd 100644
--- a/aten/src/ATen/native/metal/ops/MetalPooling.mm
+++ b/aten/src/ATen/native/metal/ops/MetalPooling.mm
@@ -1,7 +1,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -60,7 +60,7 @@
                  .y = mpscnn::computeMPSAlignOffset(kernel_size[1], padding[1]),
                  .z = 0}];
   MetalTensorImplStorage mt{IntArrayRef(outputSize).vec()};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
   MPSImage* Y = mt.texture()->image();
   [pool encodeToCommandBuffer:commandBuffer.buffer
@@ -93,7 +93,7 @@
                    .z = 0}];
 
   MetalTensorImplStorage mt{IntArrayRef(outputSize).vec()};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
   MPSImage* Y = mt.texture()->image();
   [pool encodeToCommandBuffer:commandBuffer.buffer
diff --git a/aten/src/ATen/native/metal/ops/MetalReduce.mm b/aten/src/ATen/native/metal/ops/MetalReduce.mm
index 29a4bd9..5c3129b 100644
--- a/aten/src/ATen/native/metal/ops/MetalReduce.mm
+++ b/aten/src/ATen/native/metal/ops/MetalReduce.mm
@@ -2,7 +2,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
 #import <ATen/native/metal/mpscnn/MPSImageUtils.h>
@@ -40,7 +40,7 @@
     // TODO: [T87340633] Support reducing the batch dimension
     TORCH_CHECK(imageSize[0] == 1);
     auto mask = make_dim_mask(dims, input.dim());
-    MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+    MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
     MPSImage* Y = nil;
     for (int dim : dims) {
       imageSize[dim] = 1;
diff --git a/aten/src/ATen/native/metal/ops/MetalReshape.mm b/aten/src/ATen/native/metal/ops/MetalReshape.mm
index 64b3c8d..28dbae2 100644
--- a/aten/src/ATen/native/metal/ops/MetalReshape.mm
+++ b/aten/src/ATen/native/metal/ops/MetalReshape.mm
@@ -1,7 +1,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -31,7 +31,7 @@
     return makeTensor({inferred_size, stride_value}, input.options());
   }
   MPSImage* X = imageFromTensor(input);
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   MetalTensorImplStorage mt{inferred_size, stride_value};
   mt.texture()->allocateTemporaryStorage(inferred_size, commandBuffer);
   MPSImage* Y = mt.texture()->image();
diff --git a/aten/src/ATen/native/metal/ops/MetalSoftmax.mm b/aten/src/ATen/native/metal/ops/MetalSoftmax.mm
index 181d1ff..bd22a0a 100644
--- a/aten/src/ATen/native/metal/ops/MetalSoftmax.mm
+++ b/aten/src/ATen/native/metal/ops/MetalSoftmax.mm
@@ -2,7 +2,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
 #import <ATen/native/metal/mpscnn/MPSImageUtils.h>
@@ -39,7 +39,7 @@
   // https://developer.apple.com/documentation/metalperformanceshaders/mpscnnsoftmax?changes=_1&language=objc
   T* softmax = [[T alloc] initWithDevice:[MetalContext sharedInstance].device];
   MetalTensorImplStorage mt{newSize};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input_);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
   mt.texture()->allocateTemporaryStorage(newSize, commandBuffer);
   MPSImage* Y = mt.texture()->image();
   [softmax encodeToCommandBuffer:commandBuffer.buffer
diff --git a/aten/src/ATen/native/metal/ops/MetalTranspose.mm b/aten/src/ATen/native/metal/ops/MetalTranspose.mm
index 3adb0e0..a7017fb 100644
--- a/aten/src/ATen/native/metal/ops/MetalTranspose.mm
+++ b/aten/src/ATen/native/metal/ops/MetalTranspose.mm
@@ -1,7 +1,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -14,6 +14,16 @@
 namespace native {
 namespace metal {
 
+// TODO: Move this function to MetalContext
+template<typename T>
+id<MTLBuffer> _makeMTLBuffer(const std::vector<T>& src) {
+    id<MTLBuffer> buffer = [[MetalContext sharedInstance].device
+          newBufferWithLength:src.size() * sizeof(T)
+                      options:MTLResourceOptionCPUCacheModeWriteCombined];
+    memcpy(buffer.contents, src.data(), src.size() * sizeof(T));
+    return buffer;
+}
+
 Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) {
   TORCH_CHECK(input.is_metal());
   auto ndims = input.dim();
@@ -27,7 +37,7 @@
   auto outputSizes = input.sizes().vec();
   std::swap(outputSizes[dim0], outputSizes[dim1]);
   MPSImage* X = imageFromTensor(input);
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   if (input.dim() == 2) {
     MetalTensorImplStorage mt{outputSizes};
     mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
@@ -40,9 +50,9 @@
     auto output = makeTensor(std::move(mt), input.options());
     return output;
   } else {
-    id<MTLBuffer> sizeBuf1 = makeMTLBuffer<ushort>(
+    id<MTLBuffer> sizeBuf1 = _makeMTLBuffer<ushort>(
         std::vector<ushort>{input.sizes().begin(), input.sizes().end()});
-    id<MTLBuffer> sizeBuf2 = makeMTLBuffer<ushort>(
+    id<MTLBuffer> sizeBuf2 = _makeMTLBuffer<ushort>(
         std::vector<ushort>{outputSizes.begin(), outputSizes.end()});
     MetalTensorImplStorage mt{outputSizes};
     mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
diff --git a/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm b/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm
index 049aefe..c5c008b 100644
--- a/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm
+++ b/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm
@@ -1,7 +1,7 @@
 #import <ATen/native/metal/MetalCommandBuffer.h>
 #import <ATen/native/metal/MetalTensorImpl.h>
 #import <ATen/native/metal/MetalTensorImplStorage.h>
-#import <ATen/native/metal/MetalUtils.h>
+#import <ATen/native/metal/MetalTensorUtils.h>
 #import <ATen/native/metal/MetalContext.h>
 #import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
@@ -46,7 +46,7 @@
   }
   MPSImage* X = imageFromTensor(input);
   MetalTensorImplStorage mt{outputSizes};
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
   mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
   MPSImage* Y = mt.texture()->image();
   if (@available(iOS 11.0, *)) {