blob: 0413b9ff28482441ab13f624f87adf6bd9c01c21 [file] [log] [blame]
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <ostream>
#include <string>
namespace c10 {
class SymNodeImpl;
using SymNode = c10::intrusive_ptr<SymNodeImpl>;
// When you add a method, you also need to edit
// torch/csrc/jit/python/init.cpp
// torch/csrc/utils/python_symnode.h
// c10/core/ConstantSymNodeImpl.h
class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
public:
~SymNodeImpl() override = default;
template <typename T>
c10::intrusive_ptr<T> dyn_cast() const {
return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
}
// these could be pure virtual when we implement LTC versions
virtual bool is_int() {
TORCH_CHECK(false, "NYI");
};
virtual bool is_bool() {
TORCH_CHECK(false, "NYI");
};
virtual bool is_float() {
TORCH_CHECK(false, "NYI");
};
virtual bool is_nested_int() const {
return false;
};
virtual SymNode add(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sub(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode mul(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode truediv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode pow(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode floordiv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode mod(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode eq(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode ne(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode gt(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode lt(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode le(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode ge(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode ceil() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode floor() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode neg() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_min(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_max(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_or(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_and(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_not() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) {
TORCH_CHECK(false, "NYI");
};
// NB: self is ignored here, only the arguments are used
virtual SymNode is_contiguous(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_contiguous_2d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_contiguous_3d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_strides_2d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_strides_3d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_non_overlapping_and_dense(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode clone() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_float() {
TORCH_CHECK(false, "NYI");
}
virtual SymNode wrap_int(int64_t num) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode wrap_float(double num) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode wrap_bool(bool num) {
TORCH_CHECK(false, "NYI");
};
virtual int64_t guard_int(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual bool guard_bool(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual double guard_float(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual bool guard_size_oblivious(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!
return guard_bool(file, line);
}
virtual bool expect_true(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!
return guard_bool(file, line);
};
virtual bool expect_size(const char* file, int64_t line) {
// No improvement for unbacked SymInts by default, replace this
// with a better implementation!
return ge(wrap_int(0))->guard_bool(file, line);
};
virtual int64_t int_() {
TORCH_CHECK(false, "NYI");
};
virtual bool bool_() {
TORCH_CHECK(false, "NYI");
};
virtual bool has_hint() {
TORCH_CHECK(false, "NYI");
};
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
virtual c10::optional<int64_t> nested_int() {
return c10::nullopt;
}
virtual c10::optional<int64_t> nested_int_coeff() {
return c10::nullopt;
}
virtual c10::optional<int64_t> constant_int() {
return c10::nullopt;
}
virtual c10::optional<bool> constant_bool() {
return c10::nullopt;
}
virtual c10::optional<int64_t> maybe_as_int() {
return c10::nullopt;
}
virtual bool is_constant() {
return false;
}
virtual bool is_symbolic() {
return true;
}
std::ostream& operator<<(std::ostream& os) {
os << str();
return os;
}
};
} // namespace c10