[JIT SSA] Added testing for the Cat Op in LazyTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76552
Approved by: https://github.com/Krovatkin
diff --git a/test/cpp/lazy/test_symbolic_shape.cpp b/test/cpp/lazy/test_symbolic_shape.cpp
index f0ce5a3..b2224ae 100644
--- a/test/cpp/lazy/test_symbolic_shape.cpp
+++ b/test/cpp/lazy/test_symbolic_shape.cpp
@@ -128,7 +128,32 @@
expected = {false, true};
EXPECT_EQ(getIsSymbolic(res), expected);
};
-#endif // FBCODE_CAFFE2
+TEST_F(LazyShapeTest, TestCatBasic) {
+ // Basic propagation
+ torch::Tensor a = tensorWithSymbolicShape({2, 2}, {true, false});
+ torch::Tensor b = tensorWithSymbolicShape({2, 2}, {true, false});
+ torch::Tensor c = tensorWithSymbolicShape({2, 2}, {true, false});
+
+ auto res = torch::cat({a, b, c}, 1);
+ std::vector<bool> expected = {true, false};
+ EXPECT_EQ(getIsSymbolic(res), expected);
+
+ torch::Tensor d = tensorWithSymbolicShape({2, 2}, {false, true});
+ res = torch::cat({a, d}, 0);
+ expected = {true, false};
+ EXPECT_EQ(getIsSymbolic(res), expected);
+
+ // Test handling of symbolic dims of inequal sizes, Currently crashes
+ // As we can't handle cases where upper bound dims are not equal
+ /*
+ torch::Tensor e = tensorWithSymbolicShape({2, 2}, {true, false});
+ torch::Tensor f = tensorWithSymbolicShape({2, 3}, {false, true});
+ res = torch::cat({e, f}, 0);
+ expected = {true, false};
+ EXPECT_EQ(getIsSymbolic(res), expected);
+ */
+}
+#endif // FBCODE_CAFFE2
} // namespace lazy
} // namespace torch