blob: a7732689526f1d77ce2fb371761e851efae4936c [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#import <Metal/Metal.h>
#include <Availability.h>
#include <utility>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/status.h"
// Compile-time message helpers: build a "NAME=value" string literal for a
// preprocessor macro so it can be printed with `#pragma message`.
// VALUE_TO_STRING stringizes its argument exactly as written (no expansion).
#define VALUE_TO_STRING(x) #x
// VALUE adds one level of indirection so that `x` is macro-expanded before it
// reaches the `#` operator, yielding the macro's value rather than its name.
#define VALUE(x) VALUE_TO_STRING(x)
// VAR_NAME_VALUE(FOO) produces the adjacent string literals "FOO" "=" "<value
// of FOO>", which the compiler concatenates into a single "FOO=<value>".
#define VAR_NAME_VALUE(var) #var "=" VALUE(var)
namespace tflite {
namespace gpu {
namespace metal {
// Returns the Metal device chosen by MTLCreateSystemDefaultDevice().
// NOTE(review): per Apple's documentation this can be nil on systems without
// Metal support; callers should be prepared for that.
id<MTLDevice> GetBestSupportedMetalDevice() {
  id<MTLDevice> systemDefaultDevice = MTLCreateSystemDefaultDevice();
  return systemDefaultDevice;
}
// Compiles the Metal shader source `code` with the given preprocessor `macros`
// and builds a compute pipeline for the kernel named `functionName`.
//
// Parameters:
//   device       - Metal device used to compile the library and the pipeline.
//   code         - Metal Shading Language source.
//   functionName - name of the kernel function inside `code`.
//   macros       - preprocessor macros passed to the shader compiler.
//   program      - out-parameter; receives the compute pipeline on success.
//
// Returns OkStatus on success. Returns InvalidArgumentError if `program` is
// null, or InternalError naming the Metal call that failed.
absl::Status CreateComputeProgram(id<MTLDevice> device, NSString* code, NSString* functionName,
                                  NSDictionary<NSString*, NSString*>* macros,
                                  id<MTLComputePipelineState>* program) {
  // The result is written through `program`; refuse a null out-pointer rather
  // than crashing on the dereference below.
  if (program == nullptr) {
    return absl::InvalidArgumentError("CreateComputeProgram: program must not be null");
  }
  MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
  // Runtime checks for the OS version independently of the minimum deployment
  // target: select the newest Metal Shading Language version this OS offers.
  if (@available(macOS 10.14, iOS 12.0, tvOS 12.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_1];
  } else if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion2_0];
  } else if (@available(macOS 10.12, iOS 10.0, tvOS 10.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion1_2];
  } else if (@available(macOS 10.11, iOS 9.0, tvOS 9.0, *)) {
    [options setLanguageVersion:MTLLanguageVersion1_1];
  }
#if (defined(__MAC_10_11) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_11) || \
    (defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0) || \
    (defined(__TVOS_9_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_9_0)
// Minimum target OS version is able to support Metal.
#else
#pragma message(VAR_NAME_VALUE(__MAC_OS_X_VERSION_MIN_REQUIRED))
#pragma message(VAR_NAME_VALUE(__IPHONE_OS_VERSION_MIN_REQUIRED))
#pragma message(VAR_NAME_VALUE(__TV_OS_VERSION_MIN_REQUIRED))
// NOLINTBEGIN
#error \
    "The Metal delegate is not supported on current target SDK. Minimum supported os: iOS/tvOS 9.0, macOS 10.11"
// NOLINTEND
#endif
  [options setFastMathEnabled:YES];
  [options setPreprocessorMacros:macros];

  NSError* error = nil;
  id<MTLLibrary> library = [device newLibraryWithSource:code options:options error:&error];
  // Check the returned object, not `error`: the compiler may populate `error`
  // with warnings even when compilation succeeds.
  if (!library) {
    NSString* errorString =
        [NSString stringWithFormat:@"newLibraryWithSource: %@", [error localizedDescription]];
    return absl::InternalError([errorString UTF8String]);
  }

  // -newFunctionWithName: signals failure only by returning nil; it has no
  // error out-parameter, so report the missing function name explicitly
  // instead of the (nil) `error` left over from the previous call.
  id<MTLFunction> function = [library newFunctionWithName:functionName];
  if (!function) {
    NSString* errorString = [NSString
        stringWithFormat:@"newFunctionWithName: function '%@' not found in library", functionName];
    return absl::InternalError([errorString UTF8String]);
  }

  *program = [device newComputePipelineStateWithFunction:function error:&error];
  // Test the pipeline that was produced, not the (non-null) out-pointer.
  if (!*program) {
    NSString* errorString =
        [NSString stringWithFormat:@"newComputePipelineStateWithFunction error: %@",
                                   [error localizedDescription]];
    return absl::InternalError([errorString UTF8String]);
  }
  return absl::OkStatus();
}
// Returns the size in bytes of one pixel for the four-channel (RGBA) Metal
// pixel formats handled here, or -1 for any other format.
int PixelFormatToSizeInBytes(MTLPixelFormat pixel_format) {
  switch (pixel_format) {
    // Four 32-bit channels.
    case MTLPixelFormatRGBA32Uint:
    case MTLPixelFormatRGBA32Sint:
    case MTLPixelFormatRGBA32Float:
      return 16;
    // Four 16-bit channels.
    case MTLPixelFormatRGBA16Unorm:
    case MTLPixelFormatRGBA16Snorm:
    case MTLPixelFormatRGBA16Uint:
    case MTLPixelFormatRGBA16Sint:
    case MTLPixelFormatRGBA16Float:
      return 8;
    // Four 8-bit channels.
    case MTLPixelFormatRGBA8Unorm:
    case MTLPixelFormatRGBA8Snorm:
    case MTLPixelFormatRGBA8Uint:
    case MTLPixelFormatRGBA8Sint:
      return 4;
    default:
      // Sentinel for formats this helper does not recognize.
      return -1;
  }
}
// Maps a tensor element type to the four-channel (RGBA) Metal pixel format
// with matching per-channel storage. For the 8- and 16-bit integer types,
// `normalized` selects the Unorm/Snorm variant instead of the raw Uint/Sint
// one; it is ignored for every other type. Types without a mapping yield
// MTLPixelFormatInvalid.
MTLPixelFormat DataTypeToRGBAPixelFormat(DataType type, bool normalized) {
  MTLPixelFormat format = MTLPixelFormatInvalid;
  switch (type) {
    // Floating-point types: one matching format each.
    case DataType::FLOAT16:
      format = MTLPixelFormatRGBA16Float;
      break;
    case DataType::FLOAT32:
      format = MTLPixelFormatRGBA32Float;
      break;
    // 32-bit integers: always the raw integer format.
    case DataType::INT32:
      format = MTLPixelFormatRGBA32Sint;
      break;
    case DataType::UINT32:
      format = MTLPixelFormatRGBA32Uint;
      break;
    // Narrow integers: normalized vs. raw integer storage.
    case DataType::UINT8:
      if (normalized) {
        format = MTLPixelFormatRGBA8Unorm;
      } else {
        format = MTLPixelFormatRGBA8Uint;
      }
      break;
    case DataType::INT8:
      if (normalized) {
        format = MTLPixelFormatRGBA8Snorm;
      } else {
        format = MTLPixelFormatRGBA8Sint;
      }
      break;
    case DataType::UINT16:
      if (normalized) {
        format = MTLPixelFormatRGBA16Unorm;
      } else {
        format = MTLPixelFormatRGBA16Uint;
      }
      break;
    case DataType::INT16:
      if (normalized) {
        format = MTLPixelFormatRGBA16Snorm;
      } else {
        format = MTLPixelFormatRGBA16Sint;
      }
      break;
    default:
      // Unsupported type: keep MTLPixelFormatInvalid.
      break;
  }
  return format;
}
} // namespace metal
} // namespace gpu
} // namespace tflite