|  | import sys | 
|  |  | 
|  | import torch | 
|  |  | 
|  |  | 
|  | def first_arg(x, y): | 
|  | return x[y] | 
|  |  | 
|  |  | 
|  | def second_arg(x, y): | 
|  | return x[:, y] | 
|  |  | 
|  |  | 
|  | def same_pm_one(x, y): | 
|  | return x[y + 1, y - 1] | 
|  |  | 
|  |  | 
|  | def same_pp_one(x, y): | 
|  | return x[y + 1, y + 1] | 
|  |  | 
|  |  | 
|  | def store(x, y, z): | 
|  | x[y + 1, y + 1] = z | 
|  |  | 
|  |  | 
|  | if __name__ == "__main__": | 
|  | _, fn_name, dims, dyn_shape, one_size = sys.argv | 
|  | assert fn_name in ("first_arg", "second_arg", "same_pm_one", "same_pp_one", "store") | 
|  | assert one_size in ("True", "False") | 
|  | one_size = one_size == "True" | 
|  | assert dims in ("2", "3") | 
|  | shape_x = [3, 2, 4] if dims == "3" else [3, 2] | 
|  | if one_size: | 
|  | assert ( | 
|  | fn_name == "first_arg" | 
|  | ), "only first_arg can be tested for a special case of 1-size tensor" | 
|  | shape_x[0] = 1 | 
|  | assert dyn_shape in ("True", "False") | 
|  | dynamic_shapes = dyn_shape == "True" | 
|  |  | 
|  | x = torch.randn(shape_x, device="cuda") | 
|  | y = torch.arange(4, device="cuda") | 
|  | fn = vars()[fn_name] | 
|  | fn = torch.compile(dynamic=dynamic_shapes)(fn) | 
|  | if fn_name == "store": | 
|  | shape = (y.numel(),) + x.shape[2:] | 
|  | z = torch.randn(shape, device="cuda") | 
|  | fn(x, y, z) | 
|  | else: | 
|  | fn(x, y) |