blob: 57e83ff57a804b256eced7bacb6f14484592fbb0 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FUNCTION_TRACE_TYPE_STANDARD_PRIMITIVE_TYPES_H_
#define TENSORFLOW_CORE_FUNCTION_TRACE_TYPE_STANDARD_PRIMITIVE_TYPES_H_
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/core/function/trace_type/standard/trace_type.h"
namespace tensorflow {
namespace trace_type {
// Represents cases where the value is not defined.
//
// This header only declares the TraceType overrides; their definitions are not
// in this file (presumably in the matching .cc — TODO confirm), so the exact
// subtyping / supertype rules for None cannot be read from here.
class None : public TraceType {
public:
explicit None();
// Returns a new TraceType owned by the caller (per the unique_ptr return
// type); presumably an equivalent None.
std::unique_ptr<TraceType> clone() const override;
// Whether this None is a subtype of `other` (rules defined out of line).
bool is_subtype_of(const TraceType& other) const override;
// Most specific common supertype of this and every element of `others`, or
// presumably nullptr when none exists — confirm against the definition.
std::unique_ptr<TraceType> most_specific_common_supertype(
const std::vector<const TraceType*>& others) const override;
// Human-readable rendering of this type.
std::string to_string() const override;
// Hash consistent with operator== (required for use as a trace key).
std::size_t hash() const override;
// Structural equality against an arbitrary TraceType.
bool operator==(const TraceType& other) const override;
};
// Represents type hierarchies that have a generic top type.
//
// An Any either wraps a `base` TraceType or, when constructed with
// absl::nullopt, is the top type itself. The wrapped base is owned by this
// object: it is held in `base_` as a std::unique_ptr, which also makes Any
// non-copyable (use clone() instead). Method definitions are not in this
// header (presumably in the matching .cc — TODO confirm).
class Any : public TraceType {
public:
// Passing in absl::nullopt instantiates the top type.
// The unique_ptr is taken by value, so ownership of the base (if any) passes
// to this object (presumably moved into base_).
explicit Any(absl::optional<std::unique_ptr<TraceType>> base);
// Returns a caller-owned copy. Because base_ holds a unique_ptr, the base
// presumably has to be deep-copied via its own clone() — TODO confirm.
std::unique_ptr<TraceType> clone() const override;
// Non-owning view of the wrapped base; presumably absl::nullopt for the top
// type and otherwise a pointer into base_, so it must not outlive this Any.
absl::optional<const TraceType*> base() const;
// Whether this Any is a subtype of `other` (rules defined out of line).
bool is_subtype_of(const TraceType& other) const override;
// Most specific common supertype of this and every element of `others`, or
// presumably nullptr when none exists — confirm against the definition.
std::unique_ptr<TraceType> most_specific_common_supertype(
const std::vector<const TraceType*>& others) const override;
// Human-readable rendering of this type.
std::string to_string() const override;
// Hash consistent with operator==.
std::size_t hash() const override;
// Structural equality against an arbitrary TraceType.
bool operator==(const TraceType& other) const override;
private:
// The wrapped base type; absl::nullopt means "this is the top type".
absl::optional<std::unique_ptr<TraceType>> base_;
};
// TODO(b/231340870): Add support for other types such as tf.dtype.
template <typename T>
class Literal : public TraceType {
public:
explicit Literal(T value) : value_(value) {}
std::unique_ptr<TraceType> clone() const override {
return std::unique_ptr<TraceType>(new Literal(value_));
}
const T& value() const { return value_; }
bool is_subtype_of(const TraceType& other) const override {
return *this == other;
}
std::unique_ptr<TraceType> most_specific_common_supertype(
const std::vector<const TraceType*>& others) const override {
for (const auto& other : others) {
if (*this != *other) return nullptr;
}
return std::unique_ptr<TraceType>(new Literal<T>(value_));
}
std::string to_string() const override;
std::size_t hash() const override { return std::hash<T>()(value_); }
bool operator==(const TraceType& other) const override {
const Literal<T>* casted_other = dynamic_cast<const Literal<T>*>(&other);
if (casted_other == nullptr) return false;
return casted_other->value() == value_;
}
private:
T value_;
};
template <>
inline std::string Literal<int>::to_string() const {
  // Renders the wrapped int in decimal inside "Int<...>", e.g. "Int<3>".
  std::string rendered = "Int<";
  rendered += std::to_string(value_);
  rendered += ">";
  return rendered;
}
template <>
inline std::string Literal<bool>::to_string() const {
  // Uses capitalized "True"/"False" rather than C++'s "true"/"false", giving
  // "Bool<True>" or "Bool<False>".
  const char* const spelled = value_ ? "True" : "False";
  return std::string("Bool<").append(spelled).append(">");
}
template <>
inline std::string Literal<std::string>::to_string() const {
  // The stored string is embedded verbatim (no quoting or escaping), so the
  // value "abc" renders as "String<abc>".
  std::string rendered("String<");
  rendered.append(value_);
  rendered.append(">");
  return rendered;
}
} // namespace trace_type
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FUNCTION_TRACE_TYPE_STANDARD_PRIMITIVE_TYPES_H_