blob: 8773f0c0b043e673b25b8e910c3b9c14a4e7f36f [file]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Constant.h>
#include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Types.h>
namespace vkcompute {
using ValueRef = int32_t;
constexpr ValueRef kDummyValueRef = -1;
inline bool is_valid(ValueRef value_ref) {
return value_ref >= 0;
}
struct IOValueRef {
ValueRef value;
ValueRef staging;
// Custom cast to ValueRef
operator ValueRef() const {
return value;
};
};
/*
* This class is modelled after c10::IValue; however, it is simplified and does
* not support as many types. However, the core design is the same; it is a
* tagged union over the types supported by the Vulkan Graph type.
*/
struct Value final {
private:
/*
* The union type which is used to store the value of the Value.
*/
union Payload {
/*
* Similar to IValue::Payload, trivially copyable types are nested in their
* own union.
*/
union TriviallyCopyablePayload {
TriviallyCopyablePayload() : as_int(0) {}
int64_t as_int;
double as_double;
bool as_bool;
} u;
api::vTensor as_tensor;
api::StagingBuffer as_staging;
TensorRef as_tensorref;
std::vector<int64_t> as_int_list;
std::vector<double> as_double_list;
std::vector<bool> as_bool_list;
// The below is a special type that is used to represent a list of other
// values stored in the graph. One application of the type is to represent
// a list of tensors or a list of optional tensors.
std::vector<ValueRef> as_value_list;
std::string as_string;
SymInt as_symint;
Payload() : u() {}
// NOLINTNEXTLINE
~Payload(){};
};
public:
//
// Copy constructor and assignment (disabled)
//
Value(const Value& rhs) = delete;
Value& operator=(const Value&) = delete;
//
// Move constructor and assignment; Move assignment is disabled but
// construction is implemented to allow for use in container types.
//
Value& operator=(Value&&) = delete;
#define CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(type_tag, member_name) \
case type_tag: \
payload.u.member_name = rhs.payload.u.member_name; \
break;
#define CASE_MOVE_MOVEABLE_TYPE(type_tag, type, member_name, dtor_name) \
case type_tag: \
new (&payload.member_name) type(std::move(rhs.payload.member_name)); \
rhs.payload.member_name.~dtor_name(); \
break;
Value(Value&& rhs) noexcept : tag(rhs.tag) {
switch (tag) {
// Scalar types
CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::INT, as_int);
CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double);
CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool);
// Tensor and tensor adjacent types
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::TENSOR, api::vTensor, as_tensor, vTensor);
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer);
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::TENSORREF, TensorRef, as_tensorref, TensorRef);
// Scalar lists
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::INTLIST, std::vector<int64_t>, as_int_list, vector);
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::DOUBLELIST, std::vector<double>, as_double_list, vector);
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::BOOLLIST, std::vector<bool>, as_bool_list, vector);
// Special types
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::VALUELIST, std::vector<ValueRef>, as_value_list, vector);
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::STRING, std::string, as_string, basic_string);
CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt);
case TypeTag::NONE:
clearToNone();
break;
}
rhs.clearToNone();
}
#undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE
#undef CASE_MOVE_MOVEABLE_TYPE
//
// Accessors
//
inline TypeTag type() const {
return tag;
}
//
// Destructor
//
~Value() {
switch (tag) {
case TypeTag::TENSOR:
payload.as_tensor.~vTensor();
break;
case TypeTag::STAGING:
payload.as_staging.~StagingBuffer();
break;
case TypeTag::TENSORREF:
payload.as_tensorref.~TensorRef();
break;
case TypeTag::INTLIST:
payload.as_int_list.~vector();
break;
case TypeTag::DOUBLELIST:
payload.as_double_list.~vector();
break;
case TypeTag::BOOLLIST:
payload.as_bool_list.~vector();
break;
case TypeTag::VALUELIST:
payload.as_value_list.~vector();
break;
case TypeTag::STRING:
payload.as_string.~basic_string();
break;
case TypeTag::SYMINT:
payload.as_symint.~SymInt();
break;
// Manually list out the types so that if a type here is added later and
// not handled the compiler can catch it.
case TypeTag::NONE:
case TypeTag::INT:
case TypeTag::DOUBLE:
case TypeTag::BOOL:
break;
}
}
//
// Constructors, isType(), toType()
//
Value() : tag(TypeTag::NONE) {}
inline bool isNone() const {
return tag == TypeTag::NONE;
}
#define SUPPORT_TRIVIALLY_COPYABLE_TYPE( \
type, type_name, type_tag, member_name) \
explicit Value(type t) : tag(type_tag) { \
payload.u.member_name = t; \
} \
inline bool is##type_name() const { \
return tag == type_tag; \
} \
inline const type& to##type_name() const { \
VK_CHECK_COND( \
is##type_name(), \
"Expected value to have type " #type_name ", got ", \
tag, \
" instead."); \
return payload.u.member_name; \
}
SUPPORT_TRIVIALLY_COPYABLE_TYPE(int64_t, Int, TypeTag::INT, as_int);
SUPPORT_TRIVIALLY_COPYABLE_TYPE(double, Double, TypeTag::DOUBLE, as_double);
SUPPORT_TRIVIALLY_COPYABLE_TYPE(bool, Bool, TypeTag::BOOL, as_bool);
#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
#define SUPPORT_TRIVIALLY_MOVEABLE_TYPE( \
type, type_name, type_tag, member_name) \
explicit Value(type&& t) : tag(type_tag) { \
new (&payload.member_name) type(std::move(t)); \
} \
inline bool is##type_name() const { \
return tag == type_tag; \
} \
inline type& to##type_name() { \
VK_CHECK_COND( \
is##type_name(), \
"Expected value to have type " #type_name ", got ", \
tag, \
" instead."); \
return payload.member_name; \
} \
inline const type& toConst##type_name() const { \
VK_CHECK_COND( \
is##type_name(), \
"Expected value to have type " #type_name ", got ", \
tag, \
" instead."); \
return payload.member_name; \
}
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
api::vTensor,
Tensor,
TypeTag::TENSOR,
as_tensor);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
api::StagingBuffer,
Staging,
TypeTag::STAGING,
as_staging);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
TensorRef,
TensorRef,
TypeTag::TENSORREF,
as_tensorref);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
std::vector<int64_t>,
IntList,
TypeTag::INTLIST,
as_int_list);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
std::vector<double>,
DoubleList,
TypeTag::DOUBLELIST,
as_double_list);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
std::vector<bool>,
BoolList,
TypeTag::BOOLLIST,
as_bool_list);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
std::vector<ValueRef>,
ValueList,
TypeTag::VALUELIST,
as_value_list);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
std::string,
String,
TypeTag::STRING,
as_string);
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint);
#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
#undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
private:
Payload payload;
TypeTag tag;
//
// Utility Functions
//
inline void clearToNone() noexcept {
payload.u.as_int = -1;
tag = TypeTag::NONE;
}
};
} // namespace vkcompute