Fix: sparse_csr_tensor segfaults when crow_indices or col_indices are non-tensors (#56723)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56723
WIP gh-56687
Test Plan: Imported from OSS
Reviewed By: H-Huang
Differential Revision: D27999919
Pulled By: ezyang
fbshipit-source-id: 7eb23c8f45f3c459efe65793caecaa6b67a187c9
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 604542f..0ab500a 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -605,13 +605,13 @@
auto r = parser.parse(args, kwargs, parsed_args);
THPObjectPtr crow_indices_dtype_attr(PyObject_GetAttrString(r.pyobject(CROW_INDICES_ARG), "dtype"));
THPObjectPtr col_indices_dtype_attr(PyObject_GetAttrString(r.pyobject(COL_INDICES_ARG), "dtype"));
- at::ScalarType crow_indices_scalar_type = reinterpret_cast<THPDtype*>(
- crow_indices_dtype_attr.get())->scalar_type;
- at::ScalarType col_indices_scalar_type = reinterpret_cast<THPDtype*>(
- col_indices_dtype_attr.get())->scalar_type;
+ at::ScalarType crow_indices_scalar_type = crow_indices_dtype_attr ? reinterpret_cast<THPDtype*>(
+ crow_indices_dtype_attr.get())->scalar_type : kInt;
+ at::ScalarType col_indices_scalar_type = col_indices_dtype_attr ? reinterpret_cast<THPDtype*>(
+ col_indices_dtype_attr.get())->scalar_type : kInt;
if (r.idx == 0) {
- const int SIZE_ARRAY_ARG = 3, TYPE_INFERENCE_ARG = 4, DEVICE_TYPE_ARG = 7, REQ_GRAD_ARG = 8;
+ const int SIZE_ARRAY_ARG = 3, TYPE_INFERENCE_ARG = 4, DEVICE_TYPE_ARG = 6, REQ_GRAD_ARG = 8;
bool type_inference = r.isNone(TYPE_INFERENCE_ARG);
const auto inferred_options = typeIdWithDefault(r, DEVICE_TYPE_ARG, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(TYPE_INFERENCE_ARG, scalar_type);
@@ -632,7 +632,7 @@
return at::sparse_csr_tensor(crow_indices, col_indices, values, r.intlist(SIZE_ARRAY_ARG),
values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
} else if (r.idx == 1) {
- const int TYPE_INFERENCE_ARG = 3, DEVICE_TYPE_ARG = 6, REQ_GRAD_ARG = 7;
+ const int TYPE_INFERENCE_ARG = 3, DEVICE_TYPE_ARG = 5, REQ_GRAD_ARG = 7;
bool type_inference = r.isNone(TYPE_INFERENCE_ARG);
const auto inferred_options = typeIdWithDefault(r, DEVICE_TYPE_ARG, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(TYPE_INFERENCE_ARG, scalar_type);