Add Float 16 Support (#751)
This CL adds the float16 data type to AmberScript, together with the plumbing required to enable the feature.
Fixes #485
diff --git a/Android.mk b/Android.mk
index 801f5cf..81bdd11 100644
--- a/Android.mk
+++ b/Android.mk
@@ -30,6 +30,7 @@
src/descriptor_set_and_binding_parser.cc \
src/engine.cc \
src/executor.cc \
+ src/float16_helper.cc \
src/format.cc \
src/parser.cc \
src/pipeline.cc \
diff --git a/docs/amber_script.md b/docs/amber_script.md
index 4ac16df..3a403a4 100644
--- a/docs/amber_script.md
+++ b/docs/amber_script.md
@@ -34,8 +34,10 @@
```
Currently each of the items in `VkPhysicalDeviceFeatures` are recognized along
-with `VariablePointerFeatures.variablePointers` and
-`VariablePointerFeatures.variablePointersStorageBuffer`.
+with:
+ * `VariablePointerFeatures.variablePointers`
+ * `VariablePointerFeatures.variablePointersStorageBuffer`
+ * `Float16Int8Features.shaderFloat16`
Extensions can be enabled with the `DEVICE_EXTENSION` and `INSTANCE_EXTENSION`
commands.
@@ -114,6 +116,7 @@
* `uint16`
* `uint32`
* `uint64`
+ * `float16`
* `float`
* `double`
* vec[2,3,4]{type}
diff --git a/samples/config_helper_vulkan.cc b/samples/config_helper_vulkan.cc
index 606c093..fff21a6 100644
--- a/samples/config_helper_vulkan.cc
+++ b/samples/config_helper_vulkan.cc
@@ -46,6 +46,7 @@
const char kVariablePointers[] = "VariablePointerFeatures.variablePointers";
const char kVariablePointersStorageBuffer[] =
"VariablePointerFeatures.variablePointersStorageBuffer";
+const char kFloat16Int8_Float16[] = "Float16Int8Features.shaderFloat16";
const char kExtensionForValidationLayer[] = "VK_EXT_debug_report";
@@ -598,8 +599,8 @@
ConfigHelperVulkan::ConfigHelperVulkan()
: available_features_(VkPhysicalDeviceFeatures()),
available_features2_(VkPhysicalDeviceFeatures2KHR()),
- variable_pointers_feature_(VkPhysicalDeviceVariablePointerFeaturesKHR()) {
-}
+ variable_pointers_feature_(VkPhysicalDeviceVariablePointerFeaturesKHR()),
+ float16_int8_feature_(VkPhysicalDeviceFloat16Int8FeaturesKHR()) {}
ConfigHelperVulkan::~ConfigHelperVulkan() {
if (vulkan_device_)
@@ -666,9 +667,10 @@
// Determine if VkPhysicalDeviceProperties2KHR should be used
for (auto& ext : required_extensions) {
- if (ext == "VK_KHR_get_physical_device_properties2") {
+ if (ext == "VK_KHR_get_physical_device_properties2")
supports_get_physical_device_properties2_ = true;
- }
+ if (ext == "VK_KHR_shader_float16_int8")
+ supports_shader_float16_int8_ = true;
}
std::vector<const char*> required_extensions_in_char;
@@ -882,9 +884,17 @@
amber::Result ConfigHelperVulkan::CreateDeviceWithFeatures2(
const std::vector<std::string>& required_features,
VkDeviceCreateInfo* info) {
+ float16_int8_feature_.sType =
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR;
+ float16_int8_feature_.pNext = nullptr;
+
variable_pointers_feature_.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTER_FEATURES_KHR;
- variable_pointers_feature_.pNext = nullptr;
+
+ if (supports_shader_float16_int8_)
+ variable_pointers_feature_.pNext = &float16_int8_feature_;
+ else
+ variable_pointers_feature_.pNext = nullptr;
available_features2_.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2_KHR;
available_features2_.pNext = &variable_pointers_feature_;
@@ -901,6 +911,8 @@
variable_pointers_feature_.variablePointers = VK_TRUE;
else if (feature == kVariablePointersStorageBuffer)
variable_pointers_feature_.variablePointersStorageBuffer = VK_TRUE;
+ else if (feature == kFloat16Int8_Float16)
+ float16_int8_feature_.shaderFloat16 = VK_TRUE;
}
VkPhysicalDeviceFeatures required_vulkan_features =
diff --git a/samples/config_helper_vulkan.h b/samples/config_helper_vulkan.h
index 8b830aa..b5d0b09 100644
--- a/samples/config_helper_vulkan.h
+++ b/samples/config_helper_vulkan.h
@@ -110,9 +110,11 @@
VkDevice vulkan_device_ = VK_NULL_HANDLE;
bool supports_get_physical_device_properties2_ = false;
+ bool supports_shader_float16_int8_ = false;
VkPhysicalDeviceFeatures available_features_;
VkPhysicalDeviceFeatures2KHR available_features2_;
VkPhysicalDeviceVariablePointerFeaturesKHR variable_pointers_feature_;
+ VkPhysicalDeviceFloat16Int8FeaturesKHR float16_int8_feature_;
};
} // namespace sample
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 7aeee3f..cd05481 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -21,6 +21,7 @@
descriptor_set_and_binding_parser.cc
engine.cc
executor.cc
+ float16_helper.cc
format.cc
parser.cc
pipeline.cc
@@ -138,6 +139,7 @@
command_data_test.cc
descriptor_set_and_binding_parser_test.cc
executor_test.cc
+ float16_helper_test.cc
format_test.cc
pipeline_test.cc
result_test.cc
diff --git a/src/amberscript/parser.cc b/src/amberscript/parser.cc
index e0fca81..12edfbc 100644
--- a/src/amberscript/parser.cc
+++ b/src/amberscript/parser.cc
@@ -70,6 +70,8 @@
return parser.Parse("R32_UINT");
if (str == "uint64")
return parser.Parse("R64_UINT");
+ if (str == "float16")
+ return parser.Parse("R16_SFLOAT");
if (str == "float")
return parser.Parse("R32_SFLOAT");
if (str == "double")
diff --git a/src/amberscript/parser_device_feature_test.cc b/src/amberscript/parser_device_feature_test.cc
index a15a602..6bd2d71 100644
--- a/src/amberscript/parser_device_feature_test.cc
+++ b/src/amberscript/parser_device_feature_test.cc
@@ -23,7 +23,8 @@
TEST_F(AmberScriptParserTest, DeviceFeature) {
std::string in = R"(
DEVICE_FEATURE vertexPipelineStoresAndAtomics
-DEVICE_FEATURE VariablePointerFeatures.variablePointersStorageBuffer)";
+DEVICE_FEATURE VariablePointerFeatures.variablePointersStorageBuffer
+DEVICE_FEATURE Float16Int8Features.shaderFloat16)";
Parser parser;
Result r = parser.Parse(in);
@@ -31,10 +32,11 @@
auto script = parser.GetScript();
const auto& features = script->GetRequiredFeatures();
- ASSERT_EQ(2U, features.size());
+ ASSERT_EQ(3U, features.size());
EXPECT_EQ("vertexPipelineStoresAndAtomics", features[0]);
EXPECT_EQ("VariablePointerFeatures.variablePointersStorageBuffer",
features[1]);
+ EXPECT_EQ("Float16Int8Features.shaderFloat16", features[2]);
}
TEST_F(AmberScriptParserTest, DeviceFeatureMissingFeature) {
diff --git a/src/buffer.cc b/src/buffer.cc
index 0b15b5f..71f295d 100644
--- a/src/buffer.cc
+++ b/src/buffer.cc
@@ -19,37 +19,11 @@
#include <cmath>
#include <cstring>
+#include "src/float16_helper.h"
+
namespace amber {
namespace {
-// Return sign value of 32 bits float.
-uint16_t FloatSign(const uint32_t hex_float) {
- return static_cast<uint16_t>(hex_float >> 31U);
-}
-
-// Return exponent value of 32 bits float.
-uint16_t FloatExponent(const uint32_t hex_float) {
- uint32_t exponent = ((hex_float >> 23U) & ((1U << 8U) - 1U)) - 112U;
- const uint32_t half_exponent_mask = (1U << 5U) - 1U;
- assert(((exponent & ~half_exponent_mask) == 0U) && "Float exponent overflow");
- return static_cast<uint16_t>(exponent & half_exponent_mask);
-}
-
-// Return mantissa value of 32 bits float. Note that mantissa for 32
-// bits float is 23 bits and this method must return uint32_t.
-uint32_t FloatMantissa(const uint32_t hex_float) {
- return static_cast<uint32_t>(hex_float & ((1U << 23U) - 1U));
-}
-
-// Convert 32 bits float |value| to 16 bits float based on IEEE-754.
-uint16_t FloatToHexFloat16(const float value) {
- const uint32_t* hex = reinterpret_cast<const uint32_t*>(&value);
- return static_cast<uint16_t>(
- static_cast<uint16_t>(FloatSign(*hex) << 15U) |
- static_cast<uint16_t>(FloatExponent(*hex) << 10U) |
- static_cast<uint16_t>(FloatMantissa(*hex) >> 13U));
-}
-
template <typename T>
T* ValuesAs(uint8_t* values) {
return reinterpret_cast<T*>(values);
@@ -82,10 +56,10 @@
return Sub<uint32_t>(buf1, buf2);
if (type::Type::IsUint64(mode, num_bits))
return Sub<uint64_t>(buf1, buf2);
- // TODO(dsinclair): Handle float16 ...
if (type::Type::IsFloat16(mode, num_bits)) {
- assert(false && "Float16 suppport not implemented");
- return 0.0;
+ float val1 = float16::HexFloatToFloat(buf1, 16);
+ float val2 = float16::HexFloatToFloat(buf2, 16);
+ return static_cast<double>(val1 - val2);
}
if (type::Type::IsFloat32(mode, num_bits))
return Sub<float>(buf1, buf2);
@@ -399,7 +373,7 @@
return sizeof(uint64_t);
}
if (type::Type::IsFloat16(mode, num_bits)) {
- *(ValuesAs<uint16_t>(ptr)) = FloatToHexFloat16(value.AsFloat());
+ *(ValuesAs<uint16_t>(ptr)) = float16::FloatToHexFloat16(value.AsFloat());
return sizeof(uint16_t);
}
if (type::Type::IsFloat32(mode, num_bits)) {
diff --git a/src/buffer_test.cc b/src/buffer_test.cc
index 9e5f614..6e1c121 100644
--- a/src/buffer_test.cc
+++ b/src/buffer_test.cc
@@ -18,6 +18,7 @@
#include <limits>
#include "gtest/gtest.h"
+#include "src/float16_helper.h"
#include "src/type_parser.h"
namespace amber {
@@ -294,4 +295,27 @@
EXPECT_TRUE(b1.CompareHistogramEMD(&b2, 0.0f).IsSuccess());
}
+TEST_F(BufferTest, SetFloat16) {
+ std::vector<Value> values;
+ values.resize(2);
+ values[0].SetDoubleValue(2.8);
+ values[1].SetDoubleValue(1234.567);
+
+ TypeParser parser;
+ auto type = parser.Parse("R16_SFLOAT");
+
+ Format fmt(type.get());
+ Buffer b;
+ b.SetFormat(&fmt);
+ b.SetData(std::move(values));
+
+ EXPECT_EQ(2, b.ElementCount());
+ EXPECT_EQ(2, b.ValueCount());
+ EXPECT_EQ(4, b.GetSizeInBytes());
+
+ auto v = b.GetValues<uint16_t>();
+ EXPECT_EQ(float16::FloatToHexFloat16(2.8f), v[0]);
+ EXPECT_EQ(float16::FloatToHexFloat16(1234.567f), v[1]);
+}
+
} // namespace amber
diff --git a/src/float16_helper.cc b/src/float16_helper.cc
new file mode 100644
index 0000000..a892aa1
--- /dev/null
+++ b/src/float16_helper.cc
@@ -0,0 +1,126 @@
+// Copyright 2019 The Amber Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/float16_helper.h"
+
+#include <cassert>
+
+// Float10
+// | 9 8 7 6 5 | 4 3 2 1 0 |
+// | exponent  | mantissa  |
+//
+// Float11
+// | 10 9 8 7 6 | 5 4 3 2 1 0 |
+// | exponent   | mantissa    |
+//
+// Float16
+// | 15 | 14 13 12 11 10 | 9 8 7 6 5 4 3 2 1 0 |
+// | s  | exponent       | mantissa            |
+//
+// Float32
+// | 31 | 30 ... 23 | 22 ... 0 |
+// | s  | exponent  | mantissa |
+
+namespace amber {
+namespace float16 {
+namespace {
+
+// Return the sign bit (0 or 1) of a 32 bits float given as its bit pattern.
+uint16_t FloatSign(const uint32_t hex_float) {
+ return static_cast<uint16_t>(hex_float >> 31U);
+}
+
+// Return 32 bits float exponent rebiased to 5 bits; asserts when it won't fit.
+uint16_t FloatExponent(const uint32_t hex_float) {
+ uint32_t exponent = ((hex_float >> 23U) & ((1U << 8U) - 1U)) - 112U;
+ const uint32_t half_exponent_mask = (1U << 5U) - 1U;
+ assert(((exponent & ~half_exponent_mask) == 0U) && "Float exponent overflow");
+ return static_cast<uint16_t>(exponent & half_exponent_mask);
+}
+
+// Return the 23 mantissa bits of a 32 bits float given as its bit pattern.
+// The result needs more than 16 bits, so this method must return uint32_t.
+uint32_t FloatMantissa(const uint32_t hex_float) {
+ return static_cast<uint32_t>(hex_float & ((1U << 23U) - 1U));
+}
+
+// Convert the 16 bits float at |value| (low byte first) to 32 bits float.
+// Adds 112 to every exponent, so zero/denormal/inf/NaN are not special-cased.
+float HexFloat16ToFloat(const uint8_t* value) {
+ uint32_t sign = (static_cast<uint32_t>(value[1]) & 0x80) << 24U;
+ uint32_t exponent = (((static_cast<uint32_t>(value[1]) & 0x7c) >> 2U) + 112U)
+ << 23U;
+ uint32_t mantissa = ((static_cast<uint32_t>(value[1]) & 0x3) << 8U |
+ static_cast<uint32_t>(value[0]))
+ << 13U;
+
+ uint32_t hex = sign | exponent | mantissa;
+ float* hex_float = reinterpret_cast<float*>(&hex);
+ return *hex_float;
+}
+
+// Convert the unsigned 11 bits float (5 exponent, 6 mantissa bits) at |value|
+// to 32 bits float. |value[1]| is not masked: bits above bit 10 must be zero.
+float HexFloat11ToFloat(const uint8_t* value) {
+ uint32_t exponent = (((static_cast<uint32_t>(value[1]) << 2U) |
+ ((static_cast<uint32_t>(value[0]) & 0xc0) >> 6U)) +
+ 112U)
+ << 23U;
+ uint32_t mantissa = (static_cast<uint32_t>(value[0]) & 0x3f) << 17U;
+
+ uint32_t hex = exponent | mantissa;
+ float* hex_float = reinterpret_cast<float*>(&hex);
+ return *hex_float;
+}
+
+// Convert the unsigned 10 bits float (5 exponent, 5 mantissa bits) at |value|
+// to 32 bits float. |value[1]| is not masked: bits above bit 9 must be zero.
+float HexFloat10ToFloat(const uint8_t* value) {
+ uint32_t exponent = (((static_cast<uint32_t>(value[1]) << 3U) |
+ ((static_cast<uint32_t>(value[0]) & 0xe0) >> 5U)) +
+ 112U)
+ << 23U;
+ uint32_t mantissa = (static_cast<uint32_t>(value[0]) & 0x1f) << 18U;
+
+ uint32_t hex = exponent | mantissa;
+ float* hex_float = reinterpret_cast<float*>(&hex);
+ return *hex_float;
+}
+
+} // namespace
+
+float HexFloatToFloat(const uint8_t* value, uint8_t bits) {
+ switch (bits) {
+ case 10:
+ return HexFloat10ToFloat(value);
+ case 11:
+ return HexFloat11ToFloat(value);
+ case 16:
+ return HexFloat16ToFloat(value);
+ }
+
+ assert(false && "Invalid bits");
+ return 0;
+}
+
+uint16_t FloatToHexFloat16(const float value) {
+ const uint32_t* hex = reinterpret_cast<const uint32_t*>(&value);
+ return static_cast<uint16_t>(
+ static_cast<uint16_t>(FloatSign(*hex) << 15U) |
+ static_cast<uint16_t>(FloatExponent(*hex) << 10U) |
+ static_cast<uint16_t>(FloatMantissa(*hex) >> 13U));
+}
+
+} // namespace float16
+} // namespace amber
diff --git a/src/float16_helper.h b/src/float16_helper.h
new file mode 100644
index 0000000..66ae634
--- /dev/null
+++ b/src/float16_helper.h
@@ -0,0 +1,45 @@
+// Copyright 2019 The Amber Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_FLOAT16_HELPER_H_
+#define SRC_FLOAT16_HELPER_H_
+
+#include <cstdint>
+
+namespace amber {
+namespace float16 {
+
+// Convert the |bits| bits float stored at |value| (low byte first) to a
+// 32 bits float based on IEEE-754. |bits| must be 10, 11 or 16.
+//
+// See https://www.khronos.org/opengl/wiki/Small_Float_Formats
+// and https://en.wikipedia.org/wiki/IEEE_754.
+//
+// Bits  Sign  Exponent  Mantissa  Exponent-Bias
+//  16     1       5        10           15
+//  11     0       5         6           15
+//  10     0       5         5           15
+//  32     1       8        23          127
+//  64     1      11        52         1023
+//
+// 11 and 10 bits floats are always positive.
+float HexFloatToFloat(const uint8_t* value, uint8_t bits);
+
+// Convert 32 bits float |value| to 16 bits float bits, truncating mantissa.
+uint16_t FloatToHexFloat16(const float value);
+
+} // namespace float16
+} // namespace amber
+
+#endif // SRC_FLOAT16_HELPER_H_
diff --git a/src/float16_helper_test.cc b/src/float16_helper_test.cc
new file mode 100644
index 0000000..5fa8c33
--- /dev/null
+++ b/src/float16_helper_test.cc
@@ -0,0 +1,32 @@
+// Copyright 2019 The Amber Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/float16_helper.h"
+#include "gtest/gtest.h"
+
+namespace amber {
+namespace float16 {
+
+using Float16HelperTest = testing::Test;
+
+TEST_F(Float16HelperTest, F32ToF16AndBack) {
+ float a = 2.5;
+
+ uint16_t half = float16::FloatToHexFloat16(a);
+ float b = float16::HexFloatToFloat(reinterpret_cast<uint8_t*>(&half), 16);
+ EXPECT_FLOAT_EQ(a, b);
+}
+
+} // namespace float16
+} // namespace amber
diff --git a/src/script.cc b/src/script.cc
index 5b54875..318e30b 100644
--- a/src/script.cc
+++ b/src/script.cc
@@ -100,7 +100,8 @@
name == "sparseResidencyAliased" ||
name == "variableMultisampleRate" || name == "inheritedQueries" ||
name == "VariablePointerFeatures.variablePointers" ||
- name == "VariablePointerFeatures.variablePointersStorageBuffer";
+ name == "VariablePointerFeatures.variablePointersStorageBuffer" ||
+ name == "Float16Int8Features.shaderFloat16";
}
type::Type* Script::ParseType(const std::string& str) {
diff --git a/src/verifier.cc b/src/verifier.cc
index a787e9d..b836349 100644
--- a/src/verifier.cc
+++ b/src/verifier.cc
@@ -21,6 +21,7 @@
#include <vector>
#include "src/command.h"
+#include "src/float16_helper.h"
namespace amber {
namespace {
@@ -57,78 +58,6 @@
std::memcpy(dst, &data, static_cast<size_t>((bits + 7) / kBitsPerByte));
}
-// Convert float |value| whose size is 16 bits to 32 bits float
-// based on IEEE-754.
-float HexFloat16ToFloat(const uint8_t* value) {
- uint32_t sign = (static_cast<uint32_t>(value[1]) & 0x80) << 24U;
- uint32_t exponent = (((static_cast<uint32_t>(value[1]) & 0x7c) >> 2U) + 112U)
- << 23U;
- uint32_t mantissa = ((static_cast<uint32_t>(value[1]) & 0x3) << 8U |
- static_cast<uint32_t>(value[0]))
- << 13U;
-
- uint32_t hex = sign | exponent | mantissa;
- float* hex_float = reinterpret_cast<float*>(&hex);
- return *hex_float;
-}
-
-// Convert float |value| whose size is 11 bits to 32 bits float
-// based on IEEE-754.
-float HexFloat11ToFloat(const uint8_t* value) {
- uint32_t exponent = (((static_cast<uint32_t>(value[1]) << 2U) |
- ((static_cast<uint32_t>(value[0]) & 0xc0) >> 6U)) +
- 112U)
- << 23U;
- uint32_t mantissa = (static_cast<uint32_t>(value[0]) & 0x3f) << 17U;
-
- uint32_t hex = exponent | mantissa;
- float* hex_float = reinterpret_cast<float*>(&hex);
- return *hex_float;
-}
-
-// Convert float |value| whose size is 10 bits to 32 bits float
-// based on IEEE-754.
-float HexFloat10ToFloat(const uint8_t* value) {
- uint32_t exponent = (((static_cast<uint32_t>(value[1]) << 3U) |
- ((static_cast<uint32_t>(value[0]) & 0xe0) >> 5U)) +
- 112U)
- << 23U;
- uint32_t mantissa = (static_cast<uint32_t>(value[0]) & 0x1f) << 18U;
-
- uint32_t hex = exponent | mantissa;
- float* hex_float = reinterpret_cast<float*>(&hex);
- return *hex_float;
-}
-
-// Convert float |value| whose size is |bits| bits to 32 bits float
-// based on IEEE-754.
-// See https://www.khronos.org/opengl/wiki/Small_Float_Formats
-// and https://en.wikipedia.org/wiki/IEEE_754.
-//
-// Sign Exponent Mantissa Exponent-Bias
-// 16 1 5 10 15
-// 11 0 5 6 15
-// 10 0 5 5 15
-// 32 1 8 23 127
-// 64 1 11 52 1023
-//
-// 11 and 10 bits floats are always positive.
-// 14 bits float is used only RGB9_E5 format in OpenGL but it does not exist
-// in Vulkan.
-float HexFloatToFloat(const uint8_t* value, uint8_t bits) {
- switch (bits) {
- case 10:
- return HexFloat10ToFloat(value);
- case 11:
- return HexFloat11ToFloat(value);
- case 16:
- return HexFloat16ToFloat(value);
- }
-
- assert(false && "Invalid bits");
- return 0;
-}
-
// This is based on "18.3. sRGB transfer functions" of
// https://www.khronos.org/registry/DataFormat/specs/1.2/dataformat.1.2.html
double SRGBToLinearValue(double sRGB) {
@@ -158,67 +87,84 @@
}
template <typename T>
-Result CheckValue(const ProbeSSBOCommand* command,
- const uint8_t* memory,
- const Value& value) {
+Result CheckActualValue(const ProbeSSBOCommand* command,
+ const T actual_value,
+ const Value& value) {
const auto comp = command->GetComparator();
const auto& tolerance = command->GetTolerances();
- const T* ptr = reinterpret_cast<const T*>(memory);
const T val = value.IsInteger() ? static_cast<T>(value.AsUint64())
: static_cast<T>(value.AsDouble());
switch (comp) {
case ProbeSSBOCommand::Comparator::kEqual:
if (value.IsInteger()) {
- if (static_cast<uint64_t>(*ptr) != static_cast<uint64_t>(val)) {
- return Result(std::to_string(*ptr) + " == " + std::to_string(val));
+ if (static_cast<uint64_t>(actual_value) != static_cast<uint64_t>(val)) {
+ return Result(std::to_string(actual_value) +
+ " == " + std::to_string(val));
}
} else {
- if (!IsEqualWithTolerance(static_cast<const double>(*ptr),
+ if (!IsEqualWithTolerance(static_cast<const double>(actual_value),
static_cast<const double>(val), kEpsilon)) {
- return Result(std::to_string(*ptr) + " == " + std::to_string(val));
+ return Result(std::to_string(actual_value) +
+ " == " + std::to_string(val));
}
}
break;
case ProbeSSBOCommand::Comparator::kNotEqual:
if (value.IsInteger()) {
- if (static_cast<uint64_t>(*ptr) == static_cast<uint64_t>(val)) {
- return Result(std::to_string(*ptr) + " != " + std::to_string(val));
+ if (static_cast<uint64_t>(actual_value) == static_cast<uint64_t>(val)) {
+ return Result(std::to_string(actual_value) +
+ " != " + std::to_string(val));
}
} else {
- if (IsEqualWithTolerance(static_cast<const double>(*ptr),
+ if (IsEqualWithTolerance(static_cast<const double>(actual_value),
static_cast<const double>(val), kEpsilon)) {
- return Result(std::to_string(*ptr) + " != " + std::to_string(val));
+ return Result(std::to_string(actual_value) +
+ " != " + std::to_string(val));
}
}
break;
case ProbeSSBOCommand::Comparator::kFuzzyEqual:
if (!IsEqualWithTolerance(
- static_cast<const double>(*ptr), static_cast<const double>(val),
+ static_cast<const double>(actual_value),
+ static_cast<const double>(val),
command->HasTolerances() ? tolerance[0].value : kEpsilon,
command->HasTolerances() ? tolerance[0].is_percent : true)) {
- return Result(std::to_string(*ptr) + " ~= " + std::to_string(val));
+ return Result(std::to_string(actual_value) +
+ " ~= " + std::to_string(val));
}
break;
case ProbeSSBOCommand::Comparator::kLess:
- if (*ptr >= val)
- return Result(std::to_string(*ptr) + " < " + std::to_string(val));
+ if (actual_value >= val)
+ return Result(std::to_string(actual_value) + " < " +
+ std::to_string(val));
break;
case ProbeSSBOCommand::Comparator::kLessOrEqual:
- if (*ptr > val)
- return Result(std::to_string(*ptr) + " <= " + std::to_string(val));
+ if (actual_value > val)
+ return Result(std::to_string(actual_value) +
+ " <= " + std::to_string(val));
break;
case ProbeSSBOCommand::Comparator::kGreater:
- if (*ptr <= val)
- return Result(std::to_string(*ptr) + " > " + std::to_string(val));
+ if (actual_value <= val)
+ return Result(std::to_string(actual_value) + " > " +
+ std::to_string(val));
break;
case ProbeSSBOCommand::Comparator::kGreaterOrEqual:
- if (*ptr < val)
- return Result(std::to_string(*ptr) + " >= " + std::to_string(val));
+ if (actual_value < val)
+ return Result(std::to_string(actual_value) +
+ " >= " + std::to_string(val));
break;
}
return {};
}
+template <typename T>
+Result CheckValue(const ProbeSSBOCommand* command,
+ const uint8_t* memory,
+ const Value& value) {
+ const T* ptr = reinterpret_cast<const T*>(memory);
+ return CheckActualValue<T>(command, *ptr, value);
+}
+
void SetupToleranceForTexels(const ProbeCommand* command,
double* tolerance,
bool* is_tolerance_percent) {
@@ -314,7 +260,7 @@
actual_values[i] = *ptr;
} else if (type::Type::IsFloat(mode) && num_bits < 32) {
actual_values[i] = static_cast<double>(
- HexFloatToFloat(actual, static_cast<uint8_t>(num_bits)));
+ float16::HexFloatToFloat(actual, static_cast<uint8_t>(num_bits)));
} else {
assert(false && "Incorrect number of bits for number.");
}
@@ -616,28 +562,32 @@
Result r;
FormatMode mode = segment.GetFormatMode();
uint32_t num_bits = segment.GetNumBits();
- if (type::Type::IsInt8(mode, num_bits))
+ if (type::Type::IsInt8(mode, num_bits)) {
r = CheckValue<int8_t>(command, ptr, value);
- else if (type::Type::IsUint8(mode, num_bits))
+ } else if (type::Type::IsUint8(mode, num_bits)) {
r = CheckValue<uint8_t>(command, ptr, value);
- else if (type::Type::IsInt16(mode, num_bits))
+ } else if (type::Type::IsInt16(mode, num_bits)) {
r = CheckValue<int16_t>(command, ptr, value);
- else if (type::Type::IsUint16(mode, num_bits))
+ } else if (type::Type::IsUint16(mode, num_bits)) {
r = CheckValue<uint16_t>(command, ptr, value);
- else if (type::Type::IsInt32(mode, num_bits))
+ } else if (type::Type::IsInt32(mode, num_bits)) {
r = CheckValue<int32_t>(command, ptr, value);
- else if (type::Type::IsUint32(mode, num_bits))
+ } else if (type::Type::IsUint32(mode, num_bits)) {
r = CheckValue<uint32_t>(command, ptr, value);
- else if (type::Type::IsInt64(mode, num_bits))
+ } else if (type::Type::IsInt64(mode, num_bits)) {
r = CheckValue<int64_t>(command, ptr, value);
- else if (type::Type::IsUint64(mode, num_bits))
+ } else if (type::Type::IsUint64(mode, num_bits)) {
r = CheckValue<uint64_t>(command, ptr, value);
- else if (type::Type::IsFloat32(mode, num_bits))
+ } else if (type::Type::IsFloat16(mode, num_bits)) {
+ r = CheckActualValue<float>(command, float16::HexFloatToFloat(ptr, 16),
+ value);
+ } else if (type::Type::IsFloat32(mode, num_bits)) {
r = CheckValue<float>(command, ptr, value);
- else if (type::Type::IsFloat64(mode, num_bits))
+ } else if (type::Type::IsFloat64(mode, num_bits)) {
r = CheckValue<double>(command, ptr, value);
- else
+ } else {
return Result("Unknown datum type");
+ }
if (!r.IsSuccess()) {
return Result("Line " + std::to_string(command->GetLine()) +
diff --git a/src/verifier_test.cc b/src/verifier_test.cc
index 42e5aba..42d4dfc 100644
--- a/src/verifier_test.cc
+++ b/src/verifier_test.cc
@@ -22,6 +22,7 @@
#include "amber/value.h"
#include "gtest/gtest.h"
#include "src/command.h"
+#include "src/float16_helper.h"
#include "src/make_unique.h"
#include "src/pipeline.h"
#include "src/type_parser.h"
@@ -1520,4 +1521,35 @@
EXPECT_TRUE(r.IsSuccess()) << r.Error();
}
+TEST_F(VerifierTest, ProbeSSBOHexFloat) {
+ Pipeline pipeline(PipelineType::kGraphics);
+ auto color_buf = pipeline.GenerateDefaultColorAttachmentBuffer();
+
+ ProbeSSBOCommand probe_ssbo(color_buf.get());
+
+ TypeParser parser;
+ auto type = parser.Parse("R16_SFLOAT");
+ Format fmt(type.get());
+
+ probe_ssbo.SetFormat(&fmt);
+ probe_ssbo.SetComparator(ProbeSSBOCommand::Comparator::kFuzzyEqual);
+ probe_ssbo.SetTolerances({ProbeCommand::Tolerance{false, 0.1}});
+
+ std::vector<Value> values;
+ values.resize(4);
+ values[0].SetDoubleValue(2.5);
+ values[1].SetDoubleValue(0.73);
+ values[2].SetDoubleValue(10.0);
+ values[3].SetDoubleValue(123.5);
+ probe_ssbo.SetValues(std::move(values));
+
+ const uint16_t ssbo[4] = {
+ float16::FloatToHexFloat16(2.5f), float16::FloatToHexFloat16(0.73f),
+ float16::FloatToHexFloat16(10.0f), float16::FloatToHexFloat16(123.5f)};
+
+ Verifier verifier;
+ Result r = verifier.ProbeSSBO(&probe_ssbo, sizeof(uint16_t) * 4, ssbo);
+ EXPECT_TRUE(r.IsSuccess()) << r.Error();
+}
+
} // namespace amber
diff --git a/src/vulkan/device.cc b/src/vulkan/device.cc
index f5f6e1c..5cf30f2 100644
--- a/src/vulkan/device.cc
+++ b/src/vulkan/device.cc
@@ -32,6 +32,7 @@
const char kVariablePointers[] = "VariablePointerFeatures.variablePointers";
const char kVariablePointersStorageBuffer[] =
"VariablePointerFeatures.variablePointersStorageBuffer";
+const char kFloat16Int8_Float16[] = "Float16Int8Features.shaderFloat16";
struct BaseOutStructure {
VkStructureType sType;
@@ -383,11 +384,9 @@
return r;
bool use_physical_device_features_2 = false;
- // Determine if VkPhysicalDeviceProperties2KHR should be used
for (auto& ext : required_extensions) {
- if (ext == "VK_KHR_get_physical_device_properties2") {
+ if (ext == "VK_KHR_get_physical_device_properties2")
use_physical_device_features_2 = true;
- }
}
VkPhysicalDeviceFeatures available_vulkan_features =
@@ -396,6 +395,7 @@
available_vulkan_features = available_features2.features;
VkPhysicalDeviceVariablePointerFeaturesKHR* var_ptrs = nullptr;
+ VkPhysicalDeviceFloat16Int8FeaturesKHR* float16_ptrs = nullptr;
void* ptr = available_features2.pNext;
while (ptr != nullptr) {
BaseOutStructure* s = static_cast<BaseOutStructure*>(ptr);
@@ -403,7 +403,10 @@
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTER_FEATURES_KHR) {
var_ptrs =
static_cast<VkPhysicalDeviceVariablePointerFeaturesKHR*>(ptr);
- break;
+ } else if (s->sType ==
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR) {
+ float16_ptrs =
+ static_cast<VkPhysicalDeviceFloat16Int8FeaturesKHR*>(ptr);
}
ptr = s->pNext;
}
@@ -432,6 +435,16 @@
return amber::Result(
"Missing variable pointers storage buffer feature");
}
+
+ if (feature == kFloat16Int8_Float16) {
+ if (float16_ptrs == nullptr) {
+ return amber::Result(
+ "Shader float 16 requested but feature not returned");
+ }
+ if (float16_ptrs->shaderFloat16 != VK_TRUE) {
+ return amber::Result("Missing float16 feature");
+ }
+ }
}
} else {
diff --git a/tests/cases/float16.amber b/tests/cases/float16.amber
new file mode 100644
index 0000000..53671db
--- /dev/null
+++ b/tests/cases/float16.amber
@@ -0,0 +1,45 @@
+#!amber
+# Copyright 2019 The Amber Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+INSTANCE_EXTENSION VK_KHR_get_physical_device_properties2
+DEVICE_EXTENSION VK_KHR_shader_float16_int8
+DEVICE_FEATURE Float16Int8Features.shaderFloat16
+
+SHADER compute f16 GLSL
+#version 450
+#extension GL_AMD_gpu_shader_half_float : enable
+
+layout(set=0, binding=0) buffer Buf {
+ float16_t h;
+} data;
+
+void main() {
+ data.h = data.h * 2.0hf;
+}
+END
+
+BUFFER buf DATA_TYPE float16 DATA
+2.4
+END
+
+PIPELINE compute pipeline
+ ATTACH f16
+
+ BIND BUFFER buf AS storage DESCRIPTOR_SET 0 BINDING 0
+END
+
+RUN pipeline 1 1 1
+
+EXPECT buf IDX 0 TOLERANCE 0.1 EQ 4.8
diff --git a/tests/run_tests.py b/tests/run_tests.py
index a988b82..f461bff 100755
--- a/tests/run_tests.py
+++ b/tests/run_tests.py
@@ -36,8 +36,6 @@
"compute_mat3x2.vkscript",
"compute_mat3x2float.vkscript",
"compute_mat3x2.amber",
- # https://github.com/KhronosGroup/SPIRV-Tools/issues/3072
- "compute_robust_buffer_access_ssbo.amber",
# Metal vertex shaders cannot simultaneously write to a buffer and return
# a value to the rasterizer rdar://48348476
# https://github.com/KhronosGroup/MoltenVK/issues/527
@@ -47,14 +45,10 @@
"draw_triangle_list_hlsl.amber",
],
"Linux": [
- # https://github.com/KhronosGroup/SPIRV-Tools/issues/3072
- "compute_robust_buffer_access_ssbo.amber",
# DXC not currently building on bot
"draw_triangle_list_hlsl.amber",
],
"Win": [
- # https://github.com/KhronosGroup/SPIRV-Tools/issues/3072
- "compute_robust_buffer_access_ssbo.amber",
# DXC not currently building on bot
"draw_triangle_list_hlsl.amber",
]
@@ -76,6 +70,8 @@
# Color attachment format is not supported
"draw_triangle_list_in_r16g16b16a16_snorm_color_frame.vkscript",
"draw_triangle_list_in_r8g8b8a8_snorm_color_frame.vkscript",
+ # No supporting device for Float16Int8Features
+ "float16.amber",
# SEGV: github.com/google/amber/issues/726
"matrices_uniform_draw.amber",
# SEGV: github.com/google/amber/issues/725