blob: c63779ba6effb37b38df01097a6b0a010789f1e7 [file] [log] [blame]
// Copyright 2019 The Amber Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TYPE_H_
#define SRC_TYPE_H_
#include <cassert>
#include <memory>
#include <string>
#include <vector>
#include "src/format_data.h"
#include "src/make_unique.h"
namespace amber {
namespace type {
class List;
class Number;
class Struct;
/// Abstract base of every data type in this header (Number, List, Struct).
/// Beyond the virtual hooks implemented by each concrete kind, it records a
/// "shape" common to all kinds: a row count and a column count (a vector is
/// one column with several rows, a matrix has several of both) and whether
/// the type is an array (runtime-sized when its size is 0).
class Type {
 public:
  Type();
  virtual ~Type();

  /// True for the signed integer-backed modes: kSInt, kSNorm and kSScaled.
  static bool IsSignedInt(FormatMode mode) {
    switch (mode) {
      case FormatMode::kSInt:
      case FormatMode::kSNorm:
      case FormatMode::kSScaled:
        return true;
      default:
        return false;
    }
  }
  /// True for the unsigned integer-backed modes: kUInt, kUNorm, kUScaled
  /// and kSRGB (sRGB data is stored as unsigned integers).
  static bool IsUnsignedInt(FormatMode mode) {
    switch (mode) {
      case FormatMode::kUInt:
      case FormatMode::kUNorm:
      case FormatMode::kUScaled:
      case FormatMode::kSRGB:
        return true;
      default:
        return false;
    }
  }
  /// True for any integer-backed mode, signed or unsigned.
  static bool IsInt(FormatMode mode) {
    return IsSignedInt(mode) || IsUnsignedInt(mode);
  }
  /// True for the floating point modes: kSFloat and kUFloat.
  static bool IsFloat(FormatMode mode) {
    switch (mode) {
      case FormatMode::kSFloat:
      case FormatMode::kUFloat:
        return true;
      default:
        return false;
    }
  }

  // Width-specific queries. The IntN variants only accept signed modes, the
  // UintN variants only unsigned modes, and the FloatN variants only float
  // modes; |num_bits| must match N exactly.
  static bool IsInt8(FormatMode mode, uint32_t num_bits) {
    return num_bits == 8 && IsSignedInt(mode);
  }
  static bool IsInt16(FormatMode mode, uint32_t num_bits) {
    return num_bits == 16 && IsSignedInt(mode);
  }
  static bool IsInt32(FormatMode mode, uint32_t num_bits) {
    return num_bits == 32 && IsSignedInt(mode);
  }
  static bool IsInt64(FormatMode mode, uint32_t num_bits) {
    return num_bits == 64 && IsSignedInt(mode);
  }
  static bool IsUint8(FormatMode mode, uint32_t num_bits) {
    return num_bits == 8 && IsUnsignedInt(mode);
  }
  static bool IsUint16(FormatMode mode, uint32_t num_bits) {
    return num_bits == 16 && IsUnsignedInt(mode);
  }
  static bool IsUint32(FormatMode mode, uint32_t num_bits) {
    return num_bits == 32 && IsUnsignedInt(mode);
  }
  static bool IsUint64(FormatMode mode, uint32_t num_bits) {
    return num_bits == 64 && IsUnsignedInt(mode);
  }
  static bool IsFloat16(FormatMode mode, uint32_t num_bits) {
    return num_bits == 16 && IsFloat(mode);
  }
  static bool IsFloat32(FormatMode mode, uint32_t num_bits) {
    return num_bits == 32 && IsFloat(mode);
  }
  static bool IsFloat64(FormatMode mode, uint32_t num_bits) {
    return num_bits == 64 && IsFloat(mode);
  }

  /// Byte size of ONE element of this type; array, vector and matrix
  /// repetition is not included.
  virtual uint32_t SizeInBytes() const = 0;
  /// Structural equality against |b|, implemented by each concrete kind.
  virtual bool Equal(const Type* b) const = 0;

  // Kind queries; each concrete subclass overrides exactly one to true.
  virtual bool IsList() const { return false; }
  virtual bool IsNumber() const { return false; }
  virtual bool IsStruct() const { return false; }

  // Downcasts to the concrete kind (defined out of line).
  List* AsList();
  Number* AsNumber();
  Struct* AsStruct();
  const List* AsList() const;
  const Number* AsNumber() const;
  const Struct* AsStruct() const;

  void SetRowCount(uint32_t size) { row_count_ = size; }
  uint32_t RowCount() const { return row_count_; }
  void SetColumnCount(uint32_t size) { column_count_ = size; }
  uint32_t ColumnCount() const { return column_count_; }

  /// Marks the type as an array whose size is not fixed (array size 0).
  void SetIsRuntimeArray() { is_array_ = true; }
  /// Marks the type as an array of |size| elements. Passing 0 leaves it a
  /// runtime array.
  void SetIsSizedArray(uint32_t size) {
    SetIsRuntimeArray();
    array_size_ = size;
  }
  bool IsArray() const { return is_array_; }
  bool IsSizedArray() const { return IsArray() && ArraySize() != 0; }
  bool IsRuntimeArray() const { return IsArray() && ArraySize() == 0; }
  uint32_t ArraySize() const { return array_size_; }

  /// True for a vector: a single column holding more than one row.
  bool IsVec() const { return ColumnCount() == 1 && RowCount() > 1; }
  /// True for a vector with exactly three rows.
  bool IsVec3() const { return IsVec() && RowCount() == 3; }
  /// True for a matrix: more than one row and more than one column.
  bool IsMatrix() const { return ColumnCount() > 1 && RowCount() > 1; }

 private:
  uint32_t row_count_ = 1;
  uint32_t column_count_ = 1;
  uint32_t array_size_ = 0;
  bool is_array_ = false;
};
class Number : public Type {
public:
explicit Number(FormatMode mode);
Number(FormatMode mode, uint32_t bits);
~Number() override;
static std::unique_ptr<Number> Int(uint32_t bits);
static std::unique_ptr<Number> Uint(uint32_t bits);
static std::unique_ptr<Number> Float(uint32_t bits);
bool IsNumber() const override { return true; }
uint32_t NumBits() const { return bits_; }
uint32_t SizeInBytes() const override { return bits_ / 8; }
bool Equal(const Type* b) const override {
if (!b->IsNumber())
return false;
auto n = b->AsNumber();
return format_mode_ == n->format_mode_ && bits_ == n->bits_;
}
FormatMode GetFormatMode() const { return format_mode_; }
private:
FormatMode format_mode_ = FormatMode::kSInt;
uint32_t bits_ = 32;
};
// The list type only holds lists of scalar float and int values.
class List : public Type {
public:
struct Member {
Member(FormatComponentType t, FormatMode m, uint32_t b)
: name(t), mode(m), num_bits(b) {}
uint32_t SizeInBytes() const { return num_bits / 8; }
FormatComponentType name = FormatComponentType::kR;
FormatMode mode = FormatMode::kSInt;
uint32_t num_bits = 0;
};
List();
~List() override;
bool IsList() const override { return true; }
bool Equal(const Type* b) const override {
if (!b->IsList())
return false;
auto l = b->AsList();
if (pack_size_in_bits_ != l->pack_size_in_bits_)
return false;
if (members_.size() != l->members_.size())
return false;
auto& lm = l->Members();
for (size_t i = 0; i < members_.size(); ++i) {
if (members_[i].name != lm[i].name)
return false;
if (members_[i].mode != lm[i].mode)
return false;
if (members_[i].num_bits != lm[i].num_bits)
return false;
}
return true;
}
void SetPackSizeInBits(uint32_t size) { pack_size_in_bits_ = size; }
uint32_t PackSizeInBits() const { return pack_size_in_bits_; }
bool IsPacked() const { return pack_size_in_bits_ > 0; }
void AddMember(FormatComponentType name, FormatMode mode, uint32_t num_bits) {
members_.push_back({name, mode, num_bits});
}
const std::vector<Member>& Members() const { return members_; }
std::vector<Member>& Members() { return members_; }
uint32_t SizeInBytes() const override;
private:
std::vector<Member> members_;
uint32_t pack_size_in_bits_ = 0;
};
class Struct : public Type {
public:
struct Member {
std::string name;
Type* type;
int32_t offset_in_bytes = -1;
int32_t array_stride_in_bytes = -1;
int32_t matrix_stride_in_bytes = -1;
bool HasOffset() const { return offset_in_bytes >= 0; }
bool HasArrayStride() const { return array_stride_in_bytes > 0; }
bool HasMatrixStride() const { return matrix_stride_in_bytes > 0; }
};
Struct();
~Struct() override;
uint32_t SizeInBytes() const override;
bool IsStruct() const override { return true; }
bool Equal(const Type* b) const override {
if (!b->IsStruct())
return false;
auto s = b->AsStruct();
if (is_stride_specified_ != s->is_stride_specified_)
return false;
if (stride_in_bytes_ != s->stride_in_bytes_)
return false;
if (members_.size() != s->members_.size())
return false;
auto& sm = s->Members();
for (size_t i = 0; i < members_.size(); ++i) {
if (members_[i].offset_in_bytes != sm[i].offset_in_bytes)
return false;
if (members_[i].array_stride_in_bytes != sm[i].array_stride_in_bytes)
return false;
if (members_[i].matrix_stride_in_bytes != sm[i].matrix_stride_in_bytes)
return false;
if (!members_[i].type->Equal(sm[i].type))
return false;
}
return true;
}
bool HasStride() const { return is_stride_specified_; }
uint32_t StrideInBytes() const { return stride_in_bytes_; }
void SetStrideInBytes(uint32_t stride) {
is_stride_specified_ = true;
stride_in_bytes_ = stride;
}
Member* AddMember(Type* type) {
members_.push_back({});
members_.back().type = type;
return &members_.back();
}
const std::vector<Member>& Members() const { return members_; }
private:
std::vector<Member> members_;
bool is_stride_specified_ = false;
uint32_t stride_in_bytes_ = 0;
};
} // namespace type
} // namespace amber
#endif // SRC_TYPE_H_