operator== for type
diff --git a/aten/src/ATen/templates/Type.cpp b/aten/src/ATen/templates/Type.cpp
index 76480bb..609a755 100644
--- a/aten/src/ATen/templates/Type.cpp
+++ b/aten/src/ATen/templates/Type.cpp
@@ -54,6 +54,10 @@
return tensor({}).fill_(s);
}
+bool Type::operator==(const Type& other) const {
+ return this->ID() == other.ID();
+}
+
${type_method_definitions}
}
diff --git a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h
index 994a5f1..0bdbbe9 100644
--- a/aten/src/ATen/templates/Type.h
+++ b/aten/src/ATen/templates/Type.h
@@ -111,6 +111,9 @@
Tensor tensorFromBlob(void * data, IntList sizes);
Tensor tensorFromBlob(void * data, IntList sizes, IntList strides);
Tensor scalarTensor(Scalar s);
+
+ bool operator==(const Type& other) const;
+
// example
// virtual Tensor * add(Tensor & a, Tensor & b) = 0;
${type_method_declarations}