[AOTI] Add more fallback ops (#126720)
Summary: These ops are either in either unit tests or TorchBench. Fixes https://github.com/pytorch/pytorch/issues/122050
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126720
Approved by: https://github.com/chenyang78
diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 8f2423d..32310e3 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -570,6 +570,8 @@
# Test some models in the cpp wrapper mode
TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \
--bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv"
+ TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \
+ --bfloat16 --inference --inductor --only llama --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv"
python benchmarks/dynamo/check_accuracy.py \
--actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \
--expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv"
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 8e20506..a8d61e4 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -2893,7 +2893,6 @@
# minimal arrayref interface only works with CPU; test crashes.
# https://github.com/pytorch/pytorch/issues/122983
"test_multi_device": fail_minimal_arrayref_interface(is_skip=True),
- "test_normal_functional": fail_with_and_without_stack_allocation(),
# undefined symbol: _Z16aoti_torch_dtypeIN3c104HalfEEiv
"test_non_contiguous_output_alias": fail_with_and_without_stack_allocation(
is_skip=True
@@ -2957,10 +2956,6 @@
CUDA_TEST_FAILURES = {
# test_failures, xfail by default, set is_skip=True to skip
"test_large_grid": fail_cuda(),
- "test_normal_functional": fail_abi_compatible_cuda(),
- # There is a double-free issue which will be fixed in another PR
- # no ABI shim fn for torch.sort; remove this when adding one
- "test_triton_kernel_multi_output_arg": fail_abi_compatible_cuda(is_skip=True),
# no runtime checks for non_abi_compatible mode
"test_runtime_checks": fail_non_abi_compatible_cuda(is_skip=True),
"test_runtime_checks_complex": fail_non_abi_compatible_cuda(is_skip=True),
@@ -3024,6 +3019,7 @@
"test_empty_graph": fail_minimal_arrayref_interface(is_skip=True),
"test_large_weight": fail_minimal_arrayref_interface(is_skip=True),
"test_large_mmaped_weights": fail_minimal_arrayref_interface(is_skip=True),
+ "test_normal_functional": fail_minimal_arrayref_interface(is_skip=True),
"test_misc_1": fail_minimal_arrayref_interface(is_skip=True),
"test_missing_output": fail_minimal_arrayref_interface(is_skip=True),
"test_model_modified_weights": fail_minimal_arrayref_interface(
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
index 2c7f05d..770042f 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
@@ -92,6 +92,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
@@ -100,6 +101,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
@@ -107,6 +109,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
@@ -117,6 +120,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
index 1dceac2..ad41565 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -99,6 +99,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0);
@@ -107,6 +108,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
@@ -114,6 +116,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
@@ -124,6 +127,7 @@
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py
index 4a300c3..9d086cc 100644
--- a/torchgen/aoti/fallback_ops.py
+++ b/torchgen/aoti/fallback_ops.py
@@ -88,6 +88,7 @@
"aten.mul.Tensor",
"aten.nanmedian.default",
"aten.native_dropout.default",
+ "aten.normal_functional.default",
"aten.nonzero.default",
"aten.ormqr.default",
"aten._pdist_backward.default",
@@ -99,6 +100,7 @@
"aten.rand.generator",
"aten.randint.default",
"aten.randint.generator",
+ "aten.randint.low",
"aten.randint.low_out",
"aten.randn.default",
"aten.randn.generator",
@@ -106,6 +108,7 @@
"aten.repeat_interleave.Tensor",
"aten.replication_pad1d_backward.default",
"aten.replication_pad2d_backward.default",
+ "aten.reshape.default",
"aten.resize_.default",
"aten.resize_as_.default",
"aten._scaled_dot_product_efficient_attention_backward.default",
diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py
index f123bc8..73152e1 100644
--- a/torchgen/gen_aoti_c_shim.py
+++ b/torchgen/gen_aoti_c_shim.py
@@ -333,6 +333,8 @@
backend_index = backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]
+ elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
+ backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
return backend_index
@@ -471,6 +473,7 @@
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
+#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif