| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" |
| |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_format.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/types/variant.h" |
| #include "tensorflow/lite/delegates/gpu/common/types.h" |
| |
| namespace tflite { |
| namespace gpu { |
| namespace gl { |
| namespace variable_accessor_internal { |
| |
| // Parse the following regex manually |
| // name(\[index\])?(\.field)? |
| VariableReference Parse(absl::string_view input) { |
| VariableReference ref; |
| auto start_index = input.find('['); |
| if (start_index != std::string::npos) { |
| auto end_index = input.rfind(']'); |
| if (end_index == std::string::npos) { |
| return ref; |
| } |
| ref.index = input.substr(start_index + 1, end_index - start_index - 1); |
| ref.name = input.substr(0, start_index); |
| ref.field = input.substr(end_index + 1); |
| } else { |
| auto dot = input.find('.'); |
| if (dot != std::string::npos) { |
| ref.name = input.substr(0, dot); |
| ref.field = input.substr(dot); |
| } else { |
| ref.name = input; |
| } |
| } |
| return ref; |
| } |
| |
| } // namespace variable_accessor_internal |
| |
| namespace { |
| |
| struct VariableTypeGetter { |
| std::string operator()(int) const { return "int"; } |
| std::string operator()(const int2&) const { return "ivec2"; } |
| std::string operator()(const std::vector<int2>&) const { return "ivec2"; } |
| std::string operator()(const int4&) const { return "ivec4"; } |
| std::string operator()(unsigned int) const { return "uint"; } |
| std::string operator()(const uint4&) const { return "uvec4"; } |
| std::string operator()(float) const { return "float"; } |
| std::string operator()(const float2&) const { return "vec2"; } |
| std::string operator()(const float4&) const { return "vec4"; } |
| std::string operator()(const std::vector<float4>&) const { return "vec4"; } |
| }; |
| |
| // Returns GLSL uniform type of the given variable. |
| std::string GetVariableType(const Variable::ValueType& value) { |
| return absl::visit(VariableTypeGetter(), value); |
| } |
| |
| template <typename T> |
| void FormatValue(std::string* result, T t) { |
| absl::StrAppend(result, t); |
| } |
| |
| template <> |
| void FormatValue(std::string* result, float t) { |
| absl::StrAppend(result, absl::StrFormat("%.9ff", t)); |
| } |
| |
| // Unfortunately absl::StrJoin with custom formatter requires formatter to use |
| // string, not std::string. Therefore, due to this compatibility issue data |
| // needs to be converted to string representation first and then joined. |
| template <typename T, int N> |
| std::vector<std::string> ToString(const std::array<T, N>& data) { |
| std::vector<std::string> result(N); |
| for (int i = 0; i < N; ++i) { |
| FormatValue(&result[i], data[i]); |
| } |
| return result; |
| } |
| |
| struct ConstGenerator { |
| template <typename T> |
| void operator()(T t) const { |
| FormatValue(result, t); |
| } |
| |
| template <typename T> |
| void operator()(const Vec2<T>& v) const { |
| absl::StrAppend(result, VariableTypeGetter()(v), "(", |
| absl::StrJoin(ToString<T, 2>(v.data_), ","), ")"); |
| } |
| |
| template <typename T> |
| void operator()(const Vec3<T>& v) const { |
| absl::StrAppend(result, VariableTypeGetter()(v), "(", |
| absl::StrJoin(ToString<T, 3>(v.data_), ","), ")"); |
| } |
| |
| template <typename T> |
| void operator()(const Vec4<T>& v) const { |
| absl::StrAppend(result, VariableTypeGetter()(v), "(", |
| absl::StrJoin(ToString<T, 4>(v.data_), ","), ")"); |
| } |
| |
| template <typename T> |
| void operator()(const std::vector<T>& v) const { |
| std::string type = VariableTypeGetter()(v); |
| absl::StrAppend(result, type, "[", v.size(), "]("); |
| bool first = true; |
| for (const auto& i : v) { |
| if (first) { |
| first = false; |
| } else { |
| absl::StrAppend(result, ","); |
| } |
| (*this)(i); |
| } |
| absl::StrAppend(result, ")"); |
| } |
| |
| std::string* result; |
| }; |
| |
| // Appends string representation of a variable value. |
| void GetValue(const Variable::ValueType& value, std::string* result) { |
| absl::visit(ConstGenerator{result}, value); |
| } |
| |
| struct SharedVariableDeclarationGenerator { |
| template <typename T> |
| void operator()(const T&) const { |
| absl::StrAppend(result, "shared ", GetVariableType(variable.value), " ", |
| variable.name, ";\n"); |
| } |
| |
| template <typename T> |
| void operator()(const std::vector<T>& v) const { |
| absl::StrAppend(result, "shared ", GetVariableType(variable.value), " ", |
| variable.name, "[", v.size(), "];\n"); |
| } |
| |
| const Variable& variable; |
| std::string* result; |
| }; |
| |
| void GenerateSharedVariableDeclaration(const Variable& variable, |
| std::string* result) { |
| absl::visit(SharedVariableDeclarationGenerator{variable, result}, |
| variable.value); |
| } |
| |
| struct UniformParameterDeclarationGenerator { |
| template <typename T> |
| void operator()(const T&) const { |
| absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ", |
| variable.name, ";\n"); |
| } |
| |
| template <typename T> |
| void operator()(const std::vector<T>& v) const { |
| absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ", |
| variable.name, "[", v.size(), "];\n"); |
| } |
| |
| const Variable& variable; |
| std::string* result; |
| }; |
| |
| void GenerateUniformParameterDeclaration(const Variable& variable, |
| std::string* result) { |
| absl::visit(UniformParameterDeclarationGenerator{variable, result}, |
| variable.value); |
| } |
| |
| struct VariableLengthGetter { |
| template <typename T> |
| bool operator()(const T&) const { |
| return false; |
| } |
| template <typename T> |
| bool operator()(const std::vector<T>&) const { |
| return true; |
| } |
| }; |
| |
| // Returns true if value is a vector |
| bool IsVariableLength(const Variable::ValueType& value) { |
| return absl::visit(VariableLengthGetter(), value); |
| } |
| |
| enum Field : uint8_t { UNKNOWN = 4, X = 0, Y = 1, Z = 2, W = 3 }; |
| |
| Field ToField(absl::string_view field_name) { |
| if (field_name.size() == 2 && field_name[0] == '.') { |
| switch (field_name[1]) { |
| case 'x': |
| return Field::X; |
| case 'y': |
| return Field::Y; |
| case 'z': |
| return Field::Z; |
| case 'w': |
| return Field::W; |
| } |
| } |
| return Field::UNKNOWN; |
| } |
| |
| struct FieldAccessor { |
| template <typename T> |
| void operator()(const T&) const {} |
| |
| template <typename T> |
| void operator()(const Vec2<T>& v) const { |
| FormatValue(result, v[field]); |
| } |
| |
| template <typename T> |
| void operator()(const Vec3<T>& v) const { |
| FormatValue(result, v[field]); |
| } |
| |
| template <typename T> |
| void operator()(const Vec4<T>& v) const { |
| FormatValue(result, v[field]); |
| } |
| |
| Field field; |
| std::string* result; |
| }; |
| |
| // Appends formatted value of the given field. |
| void GetValue(const Variable::ValueType& value, Field field, |
| std::string* result) { |
| absl::visit(FieldAccessor{field, result}, value); |
| } |
| |
| struct FieldChecker { |
| // For trivial as well as variable-length types indexed access is not allowed. |
| template <typename T> |
| bool operator()(const T&) const { |
| return false; |
| } |
| |
| template <typename T> |
| bool operator()(const Vec2<T>& v) const { |
| return field < v.size(); |
| } |
| |
| template <typename T> |
| bool operator()(const Vec3<T>& v) const { |
| return field < v.size(); |
| } |
| |
| template <typename T> |
| bool operator()(const Vec4<T>& v) const { |
| return field < v.size(); |
| } |
| |
| template <typename T> |
| bool operator()(const std::vector<T>&) const { |
| // technically accessing [0] element of an empty vector is UB, but we need |
| // only type information for this check. Therefore, construct default T and |
| // use it instead. |
| T t; |
| return (*this)(t); |
| } |
| |
| Field field; |
| }; |
| |
| // Returns true if field has field access and field is not out of bounds. |
| bool HasField(const Variable::ValueType& value, Field field) { |
| return absl::visit(FieldChecker{field}, value); |
| } |
| |
| void AssembleAccessor(absl::string_view name, absl::string_view index, |
| absl::string_view field, std::string* result) { |
| if (index.empty()) { |
| absl::StrAppend(result, name, field); |
| } else { |
| absl::StrAppend(result, name, "[", index, "]", field); |
| } |
| } |
| |
| } // namespace |
| |
| RewriteStatus VariableAccessor::Rewrite(absl::string_view input, |
| std::string* output) { |
| auto ref = variable_accessor_internal::Parse(input); |
| if (ref.name.empty()) { |
| absl::StrAppend(output, "INVALID_SYNTAX"); |
| return RewriteStatus::ERROR; |
| } |
| |
| auto it = |
| name_to_variable_.find(std::string(ref.name.data(), ref.name.size())); |
| if (it == name_to_variable_.end()) { |
| // Uniform with this name is not registered. |
| return RewriteStatus::NOT_RECOGNIZED; |
| } |
| const auto& value = it->second.value; |
| |
| if (!ref.index.empty() && !IsVariableLength(value)) { |
| // Trying to access variable by index, but it is not variable-length. |
| absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX"); |
| return RewriteStatus::ERROR; |
| } |
| |
| Field f = ToField(ref.field); |
| if (!ref.field.empty() && !HasField(value, f)) { |
| // Trying to access a variable by field, but it does not have it. |
| absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD"); |
| return RewriteStatus::ERROR; |
| } |
| |
| // Error checks are complete now. |
| |
| // All variable-length variables are encoded as-is without inlining. |
| if (!inline_values_ || IsVariableLength(value)) { |
| AssembleAccessor(it->second.name, ref.index, ref.field, output); |
| } else { |
| // Parameter + field is replaced with field value. |
| if (f != Field::UNKNOWN) { |
| GetValue(value, f, output); |
| } else { |
| // Parameter is accessed directly. |
| GetValue(value, output); |
| } |
| } |
| return RewriteStatus::SUCCESS; |
| } |
| |
| bool VariableAccessor::AddSharedVariable(Variable&& variable) { |
| const std::string name = variable.name; |
| if (!name_to_variable_.insert({name, std::move(variable)}).second) { |
| return false; |
| } |
| shared_variables_.insert(name); |
| return true; |
| } |
| |
| bool VariableAccessor::AddUniformParameter(Variable&& variable) { |
| const std::string name = variable.name; |
| if (!name_to_variable_.insert({name, std::move(variable)}).second) { |
| return false; |
| } |
| uniform_parameters_.insert(name); |
| return true; |
| } |
| |
| std::string VariableAccessor::GetConstDeclarations() const { |
| // Variable length variables are declared as const and accessed via variable |
| // with index. |
| std::string declarations; |
| for (const auto& variable : name_to_variable_) { |
| // Skip shared variables. |
| if (shared_variables_.find(variable.second.name) != |
| shared_variables_.end()) { |
| continue; |
| } |
| const auto& value = variable.second.value; |
| if (IsVariableLength(value)) { |
| absl::StrAppend(&declarations, "const ", GetVariableType(value), " ", |
| variable.second.name, "[] = "); |
| GetValue(value, &declarations); |
| absl::StrAppend(&declarations, ";\n"); |
| } |
| } |
| return declarations; |
| } |
| |
| std::string VariableAccessor::GetSharedVariableDeclarations() const { |
| std::string declarations; |
| for (const auto& name : shared_variables_) { |
| const auto& variable = name_to_variable_.at(name); |
| GenerateSharedVariableDeclaration(variable, &declarations); |
| } |
| return declarations; |
| } |
| |
| std::string VariableAccessor::GetUniformParameterDeclarations() const { |
| std::string declarations; |
| if (!inline_values_) { |
| for (const auto& name : uniform_parameters_) { |
| const auto& variable = name_to_variable_.at(name); |
| GenerateUniformParameterDeclaration(variable, &declarations); |
| } |
| } |
| return declarations; |
| } |
| |
| std::vector<Variable> VariableAccessor::GetUniformParameters() const { |
| std::vector<Variable> variables; |
| if (!inline_values_) { |
| variables.reserve(name_to_variable_.size()); |
| for (const auto& variable : name_to_variable_) { |
| variables.push_back(variable.second); |
| } |
| } |
| return variables; |
| } |
| |
| } // namespace gl |
| } // namespace gpu |
| } // namespace tflite |